[llvm] [mlir] [mlir][Python] create MLIRPythonSupport (PR #171775)
Maksim Levental via llvm-commits
llvm-commits at lists.llvm.org
Sat Dec 27 11:36:35 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/171775
>From 2ea6d2b37b0e67281ea0313db66c80e6abf38bdf Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 23:57:13 -0800
Subject: [PATCH 01/27] [mlir][Python] create MLIRPythonSupport
---
mlir/python/CMakeLists.txt | 65 ++++++++++++++++++++++++++++++--------
1 file changed, 52 insertions(+), 13 deletions(-)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 1e9f1e11d4d06..e9b1aff0455e6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,8 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
+set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
################################################################################
# Structural groupings.
@@ -524,27 +526,17 @@ declare_mlir_dialect_python_bindings(
# dependencies.
################################################################################
-set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
declare_mlir_python_extension(MLIRPythonExtension.Core
MODULE_NAME _mlir
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
MainModule.cpp
- IRAffine.cpp
- IRAttributes.cpp
- IRCore.cpp
- IRInterfaces.cpp
- IRModule.cpp
- IRTypes.cpp
Pass.cpp
Rewrite.cpp
# Headers must be included explicitly so they are installed.
- Globals.h
- IRModule.h
Pass.h
- NanobindUtils.h
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
@@ -752,8 +744,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectSMT.cpp
- # Headers must be included explicitly so they are installed.
- NanobindUtils.h
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
@@ -860,7 +850,6 @@ endif()
# once ready.
################################################################################
-set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
add_mlir_python_common_capi_library(MLIRPythonCAPI
INSTALL_COMPONENT MLIRPythonModules
INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
@@ -997,3 +986,53 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
endif()
endif()
+
+# The _mlir extension's first link dependency is expected to be nanobind's library.
+get_property(_mlir_ext_link_libs TARGET MLIRPythonModules.extension._mlir.dso PROPERTY LINK_LIBRARIES)
+list(GET _mlir_ext_link_libs 0 NB_LIBRARY_TARGET_NAME)
+if(NOT TARGET "${NB_LIBRARY_TARGET_NAME}")
+  message(FATAL_ERROR "Expected the first link dependency of the _mlir extension to be the nanobind library target, got '${NB_LIBRARY_TARGET_NAME}'.")
+endif()
+add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
+add_mlir_library(MLIRPythonSupport
+  ${PYTHON_SOURCE_DIR}/Globals.cpp
+  ${PYTHON_SOURCE_DIR}/IRAffine.cpp
+  ${PYTHON_SOURCE_DIR}/IRAttributes.cpp
+  ${PYTHON_SOURCE_DIR}/IRCore.cpp
+  ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
+  ${PYTHON_SOURCE_DIR}/IRTypes.cpp
+  EXCLUDE_FROM_LIBMLIR
+  SHARED
+  LINK_COMPONENTS
+  Support
+  LINK_LIBS
+  ${NB_LIBRARY_TARGET_NAME}
+  MLIRCAPIIR
+)
+target_link_libraries(MLIRPythonSupport PUBLIC ${NB_LIBRARY_TARGET_NAME})
+nanobind_link_options(MLIRPythonSupport)
+set_target_properties(MLIRPythonSupport PROPERTIES
+  LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+  # Needed for windows (and doesn't hurt others).
+  RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+  ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+)
+set(eh_rtti_enable)
+if(MSVC)
+  set(eh_rtti_enable /EHsc /GR)
+elseif(LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
+  set(eh_rtti_enable -frtti -fexceptions)
+endif()
+target_compile_options(MLIRPythonSupport PRIVATE ${eh_rtti_enable})
+if(APPLE)
+  # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
+  # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
+  # for downstream users that do not do something like `-undefined dynamic_lookup`.
+  # The same applies to the other symbols listed below.
+  target_link_options(MLIRPythonSupport PUBLIC
+    "LINKER:-U,_PyClassMethod_New"
+    "LINKER:-U,_PyCode_Addr2Location"
+    "LINKER:-U,_PyFrame_GetLasti"
+  )
+endif()
+target_link_libraries(MLIRPythonModules.extension._mlir.dso PUBLIC MLIRPythonSupport)
>From 243f5190d7a03ad2698d9326acca3cb2cf5c5209 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 13:19:23 -0800
Subject: [PATCH 02/27] kind of working
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 +
.../mlir}/Bindings/Python/Globals.h | 48 +-
.../mlir/Bindings/Python/IRCore.h} | 1025 ++++-
.../mlir}/Bindings/Python/NanobindUtils.h | 0
mlir/lib/Bindings/Python/DialectSMT.cpp | 2 +-
.../Python/{IRModule.cpp => Globals.cpp} | 14 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 10 +-
mlir/lib/Bindings/Python/IRAttributes.cpp | 21 +-
mlir/lib/Bindings/Python/IRCore.cpp | 3317 +----------------
mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 22 +-
mlir/lib/Bindings/Python/MainModule.cpp | 2277 ++++++++++-
mlir/lib/Bindings/Python/Pass.cpp | 4 +-
mlir/lib/Bindings/Python/Pass.h | 2 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
mlir/lib/Bindings/Python/Rewrite.h | 2 +-
mlir/python/CMakeLists.txt | 19 +-
17 files changed, 3428 insertions(+), 3340 deletions(-)
rename mlir/{lib => include/mlir}/Bindings/Python/Globals.h (82%)
rename mlir/{lib/Bindings/Python/IRModule.h => include/mlir/Bindings/Python/IRCore.h} (57%)
rename mlir/{lib => include/mlir}/Bindings/Python/NanobindUtils.h (100%)
rename mlir/lib/Bindings/Python/{IRModule.cpp => Globals.cpp} (97%)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index ca90151e76268..882781736b493 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -764,6 +764,7 @@ function(add_mlir_python_extension libname extname)
nanobind_add_module(${libname}
NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
FREE_THREADED
+ NB_SHARED
${ARG_SOURCES}
)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
similarity index 82%
rename from mlir/lib/Bindings/Python/Globals.h
rename to mlir/include/mlir/Bindings/Python/Globals.h
index 1e81f53e465ac..fea7a201453ce 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -15,10 +15,12 @@
#include <unordered_set>
#include <vector>
-#include "NanobindUtils.h"
+#include "mlir-c/Debug.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir/CAPI/Support.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -200,6 +202,50 @@ class PyGlobals {
TypeIDAllocator typeIDAllocator;
};
+/// Python binding (`_GlobalDebug`) for the global LLVM debugging flag.
+struct PyGlobalDebugFlag {
+  static void set(nanobind::object &, bool enable) {
+    nanobind::ft_lock_guard lock(mutex);
+    mlirEnableGlobalDebug(enable);
+  }
+  /// Getter of the static `flag` property; the class object is ignored.
+  static bool get(const nanobind::object &) {
+    nanobind::ft_lock_guard lock(mutex);
+    return mlirIsGlobalDebugEnabled();
+  }
+  /// Registers `_GlobalDebug` (flag + set_types overloads) on module `m`.
+  static void bind(nanobind::module_ &m) {
+    // Debug flags.
+    nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+        .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+                            &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+        .def_static(
+            "set_types",
+            [](const std::string &type) {
+              nanobind::ft_lock_guard lock(mutex);
+              mlirSetGlobalDebugType(type.c_str());
+            },
+            nanobind::arg("types"),
+            "Sets specific debug types to be produced by LLVM.")
+        .def_static(
+            "set_types",
+            [](const std::vector<std::string> &types) {
+              std::vector<const char *> pointers;
+              pointers.reserve(types.size());
+              for (const std::string &str : types)
+                pointers.push_back(str.c_str());
+              nanobind::ft_lock_guard lock(mutex);
+              mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+            },
+            nanobind::arg("types"),
+            "Sets multiple specific debug types to be produced by LLVM.");
+  }
+
+private:
+  static nanobind::ft_mutex mutex; // Must be defined out of line in one TU.
+};
+
+
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRCore.h
similarity index 57%
rename from mlir/lib/Bindings/Python/IRModule.h
rename to mlir/include/mlir/Bindings/Python/IRCore.h
index e706be3b4d32a..488196ea42e44 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1,4 +1,4 @@
-//===- IRModules.h - IR Submodules of pybind module -----------------------===//
+//===- IRCore.h - IR helpers of Python bindings ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +7,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
-#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
-#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+#ifndef MLIR_BINDINGS_PYTHON_IRCORE_H
+#define MLIR_BINDINGS_PYTHON_IRCORE_H
#include <optional>
#include <sstream>
@@ -20,12 +20,14 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ThreadPool.h"
@@ -1323,12 +1325,1017 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
-void populateIRAffine(nanobind::module_ &m);
-void populateIRAttributes(nanobind::module_ &m);
-void populateIRCore(nanobind::module_ &m);
-void populateIRInterfaces(nanobind::module_ &m);
-void populateIRTypes(nanobind::module_ &m);
+//------------------------------------------------------------------------------
+// Utilities.
+//------------------------------------------------------------------------------
+
+/// Helper for creating an @classmethod (takes ownership of the new reference).
+template <class Func, typename... Args>
+inline nanobind::object classmethod(Func f, Args... args) {
+  nanobind::object cf = nanobind::cpp_function(f, args...);
+  return nanobind::steal<nanobind::object>(PyClassMethod_New(cf.ptr()));
+}
+
+inline nanobind::object
+createCustomDialectWrapper(const std::string &dialectNamespace,
+                           nanobind::object dialectDescriptor) {
+  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
+  if (!dialectClass) {
+    // No Python class registered for this dialect: use the base class.
+    return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
+  }
+
+  // Otherwise instantiate the registered custom dialect class.
+  return (*dialectClass)(std::move(dialectDescriptor));
+}
+
+inline MlirStringRef toMlirStringRef(const std::string &s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
+/// Overload for std::string_view; like the above, the bytes are not copied.
+inline MlirStringRef toMlirStringRef(std::string_view s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
+/// Overload viewing the contents of a Python `bytes` object (no copy).
+inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
+  return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
+}
+
+/// Creates a block (not yet inserted into a region) with the given arg types.
+/// Without `pyArgLocs`, every argument uses the current context's location.
+inline MlirBlock
+createBlock(const nanobind::sequence &pyArgTypes,
+            const std::optional<nanobind::sequence> &pyArgLocs) {
+  SmallVector<MlirType> argTypes;
+  argTypes.reserve(nanobind::len(pyArgTypes));
+  for (const auto &pyType : pyArgTypes)
+    argTypes.push_back(nanobind::cast<PyType &>(pyType));
+
+  SmallVector<MlirLocation> argLocs;
+  if (pyArgLocs) {
+    argLocs.reserve(nanobind::len(*pyArgLocs));
+    for (const auto &pyLoc : *pyArgLocs)
+      argLocs.push_back(nanobind::cast<PyLocation &>(pyLoc));
+  } else if (!argTypes.empty()) {
+    argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
+  }
+  // Each block argument needs exactly one location.
+  if (argTypes.size() != argLocs.size())
+    throw nanobind::value_error(("Expected " + Twine(argTypes.size()) +
+                                 " locations, got: " + Twine(argLocs.size()))
+                                    .str()
+                                    .c_str());
+  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
+}
+
+struct PyAttrBuilderMap { // Exposed to Python as `AttrBuilder`.
+  static bool dunderContains(const std::string &attributeKind) {
+    return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
+  }
+  static nanobind::callable
+  dunderGetItemNamed(const std::string &attributeKind) {
+    auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+    if (!builder) // Unregistered kind: raise KeyError.
+      throw nanobind::key_error(attributeKind.c_str());
+    return *builder;
+  }
+  static void dunderSetItemNamed(const std::string &attributeKind,
+                                 nanobind::callable func, bool replace) {
+    PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
+                                              replace);
+  }
+  /// Registers `AttrBuilder` with static contains/get/insert on module `m`.
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyAttrBuilderMap>(m, "AttrBuilder")
+        .def_static("contains", &PyAttrBuilderMap::dunderContains,
+                    nanobind::arg("attribute_kind"),
+                    "Checks whether an attribute builder is registered for the "
+                    "given attribute kind.")
+        .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
+                    nanobind::arg("attribute_kind"),
+                    "Gets the registered attribute builder for the given "
+                    "attribute kind.")
+        .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
+                    nanobind::arg("attribute_kind"),
+                    nanobind::arg("attr_builder"),
+                    nanobind::arg("replace") = false,
+                    "Register an attribute builder for building MLIR "
+                    "attributes from Python values.");
+  }
+};
+
+//------------------------------------------------------------------------------
+// PyBlock
+//------------------------------------------------------------------------------
+
+inline nanobind::object PyBlock::getCapsule() { // Steals the capsule ref.
+  return nanobind::steal<nanobind::object>(mlirPythonBlockToCapsule(get()));
+}
+
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+class PyRegionIterator {
+public:
+  PyRegionIterator(PyOperationRef operation, intptr_t nextIndex)
+      : operation(std::move(operation)), nextIndex(nextIndex) {}
+
+  PyRegionIterator &dunderIter() { return *this; }
+  /// Yields the next region, raising StopIteration past the last one.
+  PyRegion dunderNext() {
+    operation->checkValid();
+    if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+      throw nanobind::stop_iteration();
+    }
+    MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
+    return PyRegion(operation, region);
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyRegionIterator>(m, "RegionIterator")
+        .def("__iter__", &PyRegionIterator::dunderIter,
+             "Returns an iterator over the regions in the operation.")
+        .def("__next__", &PyRegionIterator::dunderNext,
+             "Returns the next region in the iteration.");
+  }
+
+private:
+  PyOperationRef operation;
+  intptr_t nextIndex = 0; // Index of the region returned by the next call.
+};
+
+/// Regions of an op are fixed length and indexed numerically so are represented
+/// with a sequence-like container.
+class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
+public:
+  static constexpr const char *pyClassName = "RegionSequence";
+
+  PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+               intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumRegions(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+  /// NOTE(review): ignores this slice's length/step (runs to the last region).
+  PyRegionIterator dunderIter() {
+    operation->checkValid();
+    return PyRegionIterator(operation, startIndex);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__iter__", &PyRegionList::dunderIter,
+          "Returns an iterator over the regions in the sequence.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyRegionList, PyRegion>;
+  /// Number of regions of the parent operation (checked for validity).
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumRegions(operation->get());
+  }
+  /// Region at raw position `pos` of the parent operation.
+  PyRegion getRawElement(intptr_t pos) {
+    operation->checkValid();
+    return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
+  }
+  /// New list over the same operation with the given slice parameters.
+  PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyRegionList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+class PyBlockIterator {
+public:
+  PyBlockIterator(PyOperationRef operation, MlirBlock next)
+      : operation(std::move(operation)), next(next) {}
+
+  PyBlockIterator &dunderIter() { return *this; }
+  /// Yields the current block and advances to the next one in the region.
+  PyBlock dunderNext() {
+    operation->checkValid();
+    if (mlirBlockIsNull(next)) {
+      throw nanobind::stop_iteration();
+    }
+
+    PyBlock returnBlock(operation, next);
+    next = mlirBlockGetNextInRegion(next);
+    return returnBlock;
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyBlockIterator>(m, "BlockIterator")
+        .def("__iter__", &PyBlockIterator::dunderIter,
+             "Returns an iterator over the blocks in the operation's region.")
+        .def("__next__", &PyBlockIterator::dunderNext,
+             "Returns the next block in the iteration.");
+  }
+
+private:
+  PyOperationRef operation;
+  MlirBlock next; // Null once the region's blocks are exhausted.
+};
+
+/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
+/// we present them as a more full-featured list-like container but optimize
+/// it for forward iteration. Blocks are always owned by a region.
+class PyBlockList {
+public:
+  PyBlockList(PyOperationRef operation, MlirRegion region)
+      : operation(std::move(operation)), region(region) {}
+  /// Returns an iterator starting at the region's first block.
+  PyBlockIterator dunderIter() {
+    operation->checkValid();
+    return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+  }
+  /// Counts blocks by walking the region's list (linear time).
+  intptr_t dunderLen() {
+    operation->checkValid();
+    intptr_t count = 0;
+    MlirBlock block = mlirRegionGetFirstBlock(region);
+    while (!mlirBlockIsNull(block)) {
+      count += 1;
+      block = mlirBlockGetNextInRegion(block);
+    }
+    return count;
+  }
+  /// Negative indices count from the end; out of range raises IndexError.
+  PyBlock dunderGetItem(intptr_t index) {
+    operation->checkValid();
+    if (index < 0) {
+      index += dunderLen();
+    }
+    if (index < 0) {
+      throw nanobind::index_error("attempt to access out of bounds block");
+    }
+    MlirBlock block = mlirRegionGetFirstBlock(region);
+    while (!mlirBlockIsNull(block)) {
+      if (index == 0) {
+        return PyBlock(operation, block);
+      }
+      block = mlirBlockGetNextInRegion(block);
+      index -= 1;
+    }
+    throw nanobind::index_error("attempt to access out of bounds block");
+  }
+  /// Creates a block from the given arg types/locations and appends it.
+  PyBlock appendBlock(const nanobind::args &pyArgTypes,
+                      const std::optional<nanobind::sequence> &pyArgLocs) {
+    operation->checkValid();
+    MlirBlock block =
+        createBlock(nanobind::cast<nanobind::sequence>(pyArgTypes), pyArgLocs);
+    mlirRegionAppendOwnedBlock(region, block);
+    return PyBlock(operation, block);
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyBlockList>(m, "BlockList")
+        .def("__getitem__", &PyBlockList::dunderGetItem,
+             "Returns the block at the specified index.")
+        .def("__iter__", &PyBlockList::dunderIter,
+             "Returns an iterator over blocks in the operation's region.")
+        .def("__len__", &PyBlockList::dunderLen,
+             "Returns the number of blocks in the operation's region.")
+        .def("append", &PyBlockList::appendBlock,
+             R"(
+              Appends a new block, with argument types as positional args.
+
+              Returns:
+                The created block.
+            )",
+             nanobind::arg("args"), nanobind::kw_only(),
+             nanobind::arg("arg_locs") = std::nullopt);
+  }
+
+private:
+  PyOperationRef operation;
+  MlirRegion region;
+};
+
+class PyOperationIterator {
+public:
+  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
+      : parentOperation(std::move(parentOperation)), next(next) {}
+
+  PyOperationIterator &dunderIter() { return *this; }
+  /// Yields an OpView of the current op, then advances within the block.
+  nanobind::typed<nanobind::object, PyOpView> dunderNext() {
+    parentOperation->checkValid();
+    if (mlirOperationIsNull(next)) {
+      throw nanobind::stop_iteration();
+    }
+
+    PyOperationRef returnOperation =
+        PyOperation::forOperation(parentOperation->getContext(), next);
+    next = mlirOperationGetNextInBlock(next);
+    return returnOperation->createOpView();
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOperationIterator>(m, "OperationIterator")
+        .def("__iter__", &PyOperationIterator::dunderIter,
+             "Returns an iterator over the operations in an operation's block.")
+        .def("__next__", &PyOperationIterator::dunderNext,
+             "Returns the next operation in the iteration.");
+  }
+
+private:
+  PyOperationRef parentOperation;
+  MlirOperation next; // Null once the block's operations are exhausted.
+};
+
+/// Operations are exposed by the C-API as a forward-only linked list. In
+/// Python, we present them as a more full-featured list-like container but
+/// optimize it for forward iteration. Iterable operations are always owned
+/// by a block.
+class PyOperationList {
+public:
+  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
+      : parentOperation(std::move(parentOperation)), block(block) {}
+  /// Returns an iterator starting at the block's first operation.
+  PyOperationIterator dunderIter() {
+    parentOperation->checkValid();
+    return PyOperationIterator(parentOperation,
+                               mlirBlockGetFirstOperation(block));
+  }
+  /// Counts operations by walking the block's list (linear time).
+  intptr_t dunderLen() {
+    parentOperation->checkValid();
+    intptr_t count = 0;
+    MlirOperation childOp = mlirBlockGetFirstOperation(block);
+    while (!mlirOperationIsNull(childOp)) {
+      count += 1;
+      childOp = mlirOperationGetNextInBlock(childOp);
+    }
+    return count;
+  }
+  /// Negative indices count from the end; out of range raises IndexError.
+  nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index) {
+    parentOperation->checkValid();
+    if (index < 0) {
+      index += dunderLen();
+    }
+    if (index < 0) {
+      throw nanobind::index_error("attempt to access out of bounds operation");
+    }
+    MlirOperation childOp = mlirBlockGetFirstOperation(block);
+    while (!mlirOperationIsNull(childOp)) {
+      if (index == 0) {
+        return PyOperation::forOperation(parentOperation->getContext(), childOp)
+            ->createOpView();
+      }
+      childOp = mlirOperationGetNextInBlock(childOp);
+      index -= 1;
+    }
+    throw nanobind::index_error("attempt to access out of bounds operation");
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOperationList>(m, "OperationList")
+        .def("__getitem__", &PyOperationList::dunderGetItem,
+             "Returns the operation at the specified index.")
+        .def("__iter__", &PyOperationList::dunderIter,
+             "Returns an iterator over operations in the list.")
+        .def("__len__", &PyOperationList::dunderLen,
+             "Returns the number of operations in the list.");
+  }
+
+private:
+  PyOperationRef parentOperation;
+  MlirBlock block;
+};
+
+class PyOpOperand {
+public:
+  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+  /// Returns an OpView of the operation that owns this operand.
+  nanobind::typed<nanobind::object, PyOpView> getOwner() {
+    MlirOperation owner = mlirOpOperandGetOwner(opOperand);
+    PyMlirContextRef context =
+        PyMlirContext::forContext(mlirOperationGetContext(owner));
+    return PyOperation::forOperation(context, owner)->createOpView();
+  }
+  /// Position of this operand in its owner's operand list.
+  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOpOperand>(m, "OpOperand")
+        .def_prop_ro("owner", &PyOpOperand::getOwner,
+                     "Returns the operation that owns this operand.")
+        .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
+                     "Returns the operand number in the owning operation.");
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
+class PyOpOperandIterator {
+public:
+  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+  PyOpOperandIterator &dunderIter() { return *this; }
+  /// Yields the current use, then follows mlirOpOperandGetNextUse.
+  PyOpOperand dunderNext() {
+    if (mlirOpOperandIsNull(opOperand))
+      throw nanobind::stop_iteration();
+
+    PyOpOperand returnOpOperand(opOperand);
+    opOperand = mlirOpOperandGetNextUse(opOperand);
+    return returnOpOperand;
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOpOperandIterator>(m, "OpOperandIterator")
+        .def("__iter__", &PyOpOperandIterator::dunderIter,
+             "Returns an iterator over operands.")
+        .def("__next__", &PyOpOperandIterator::dunderNext,
+             "Returns the next operand in the iteration.");
+  }
+
+private:
+  MlirOpOperand opOperand; // Null once all uses have been visited.
+};
+
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and redefine bindDerived.
+  using ClassTy = nanobind::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = bool (*)(MlirValue);
+  /// Converting from a PyValue checks its kind via castFrom (may throw).
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Attempts to cast the original value to the derived type and throws on
+  /// type mismatches.
+  static MlirValue castFrom(PyValue &orig) {
+    if (!DerivedTy::isaFunction(orig.get())) {
+      auto origRepr =
+          nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
+      throw nanobind::value_error((Twine("Cannot cast value to ") +
+                                   DerivedTy::pyClassName + " (from " +
+                                   origRepr + ")")
+                                      .str()
+                                      .c_str());
+    }
+    return orig.get();
+  }
+
+  /// Binds the Python module objects to functions of this class.
+  static void bind(nanobind::module_ &m) {
+    auto cls = ClassTy(
+        m, DerivedTy::pyClassName, nanobind::is_generic(),
+        nanobind::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
+                          .str()
+                          .c_str()));
+    cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),
+            nanobind::arg("value"));
+    cls.def_static(
+        "isinstance",
+        [](PyValue &otherValue) -> bool {
+          return DerivedTy::isaFunction(otherValue);
+        },
+        nanobind::arg("other_value"));
+    cls.def(
+        MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+        [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
+          return self.maybeDownCast();
+        });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+  /// Adds the `owner` and `result_number` properties.
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro(
+        "owner",
+        [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOperation> {
+          assert(mlirOperationEqual(self.getParentOperation()->get(),
+                                    mlirOpResultGetOwner(self.get())) &&
+                 "expected the owner of the value in Python to match that in "
+                 "the IR");
+          return self.getParentOperation().getObject();
+        },
+        "Returns the operation that produces this result.");
+    c.def_prop_ro(
+        "result_number",
+        [](PyOpResult &self) {
+          return mlirOpResultGetResultNumber(self.get());
+        },
+        "Returns the position of this result in the operation's result list.");
+  }
+};
+
+/// Returns the (maybe-downcast) types of the values held by container.
+template <typename Container>
+inline std::vector<nanobind::typed<nanobind::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+  std::vector<nanobind::typed<nanobind::object, PyType>> result;
+  result.reserve(container.size());
+  for (intptr_t i = 0, e = container.size(); i < e; ++i) {
+    result.push_back(PyType(context->getRef(),
+                            mlirValueGetType(container.getElement(i).get()))
+                         .maybeDownCast());
+  }
+  return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+  static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro(
+        "types",
+        [](PyOpResultList &self) {
+          return getValueTypes(self, self.operation->getContext());
+        },
+        "Returns a list of types for all results in this result list.");
+    c.def_prop_ro(
+        "owner",
+        [](PyOpResultList &self)
+            -> nanobind::typed<nanobind::object, PyOpView> {
+          return self.operation->createOpView();
+        },
+        "Returns the operation that owns this result list.");
+  }
+  /// Returns the operation whose results this list views.
+  PyOperationRef &getOperation() { return operation; }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+  /// Number of results of the parent operation (checked for validity).
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumResults(operation->get());
+  }
+  /// Wraps the result at raw position `index` as a PyOpResult.
+  PyOpResult getRawElement(intptr_t index) {
+    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+    return PyOpResult(value);
+  }
+  /// New list over the same operation with the given slice parameters.
+  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpResultList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+/// Python wrapper for MlirBlockArgument.
+class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
+  static constexpr const char *pyClassName = "BlockArgument";
+  using PyConcreteValue::PyConcreteValue;
+  /// Adds owner/arg_number properties and set_type/set_location methods.
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro(
+        "owner",
+        [](PyBlockArgument &self) {
+          return PyBlock(self.getParentOperation(),
+                         mlirBlockArgumentGetOwner(self.get()));
+        },
+        "Returns the block that owns this argument.");
+    c.def_prop_ro(
+        "arg_number",
+        [](PyBlockArgument &self) {
+          return mlirBlockArgumentGetArgNumber(self.get());
+        },
+        "Returns the position of this argument in the block's argument list.");
+    c.def(
+        "set_type",
+        [](PyBlockArgument &self, PyType type) {
+          return mlirBlockArgumentSetType(self.get(), type);
+        },
+        nanobind::arg("type"), "Sets the type of this block argument.");
+    c.def(
+        "set_location",
+        [](PyBlockArgument &self, PyLocation loc) {
+          return mlirBlockArgumentSetLocation(self.get(), loc);
+        },
+        nanobind::arg("loc"), "Sets the location of this block argument.");
+  }
+};
+
+/// A list of block arguments. Internally, these are stored as consecutive
+/// elements, random access is cheap. The argument list is associated with the
+/// operation that contains the block (detached blocks are not allowed in
+/// Python bindings) and extends its lifetime.
+class PyBlockArgumentList
+    : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
+public:
+  static constexpr const char *pyClassName = "BlockArgumentList";
+  using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumArguments(block) : length,
+                  step),
+        operation(std::move(operation)), block(block) {}
+  /// Adds a `types` property listing the type of each argument.
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro(
+        "types",
+        [](PyBlockArgumentList &self) {
+          return getValueTypes(self, self.operation->getContext());
+        },
+        "Returns a list of types for all arguments in this argument list.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+  /// Returns the number of arguments in the list.
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirBlockGetNumArguments(block);
+  }
+
+  /// Returns the `pos`-th element in the list.
+  PyBlockArgument getRawElement(intptr_t pos) {
+    MlirValue argument = mlirBlockGetArgument(block, pos);
+    return PyBlockArgument(operation, argument);
+  }
+
+  /// Returns a sublist of this list.
+  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockArgumentList(operation, block, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  MlirBlock block;
+};
+
+/// A list of operation operands. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) operand list is associated
+/// with the operation whose operands these are, and thus extends the lifetime
+/// of this operation.
+class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
+public:
+  static constexpr const char *pyClassName = "OpOperandList";
+  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
+
+  /// Creates a view over the operands of `operation`. A `length` of -1 means
+  /// "all operands of the operation".
+  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+                  intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumOperands(operation->get())
+                               : length,
+                  step),
+        // The base is initialized before members, so moving here is safe.
+        operation(std::move(operation)) {}
+
+  /// Replaces the operand at `index` (normalized through wrapIndex) with
+  /// `value`.
+  void dunderSetItem(intptr_t index, PyValue value) {
+    index = wrapIndex(index);
+    mlirOperationSetOperand(operation->get(), index, value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem,
+          nanobind::arg("index"), nanobind::arg("value"),
+          "Sets the operand at the specified index to a new value.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpOperandList, PyValue>;
+
+  /// Returns the number of operands, after checking the operation is valid.
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumOperands(operation->get());
+  }
+
+  /// Wraps operand `pos` in a PyValue whose parent is the operation that
+  /// defines it: the producing op for a result, or the op owning the block
+  /// for a block argument.
+  PyValue getRawElement(intptr_t pos) {
+    MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
+    MlirOperation owner;
+    if (mlirValueIsAOpResult(operand)) {
+      owner = mlirOpResultGetOwner(operand);
+    } else if (mlirValueIsABlockArgument(operand)) {
+      owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
+    } else {
+      // A value is expected to be either an op result or a block argument.
+      // Fail in every build type: an assert alone would be compiled out under
+      // NDEBUG and leave `owner` uninitialized for the code below.
+      throw nanobind::value_error("Value must be a block arg or op result.");
+    }
+    PyOperationRef pyOwner =
+        PyOperation::forOperation(operation->getContext(), owner);
+    return PyValue(pyOwner, operand);
+  }
+
+  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpOperandList(operation, startIndex, length, step);
+  }
+
+  /// The operation whose operands are exposed; held to extend its lifetime.
+  PyOperationRef operation;
+};
+
+/// A list of operation successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation whose successors these are, and thus extends
+/// the lifetime of this operation.
+class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "OpSuccessors";
+
+  /// Creates a view over the successors of `operation`. A `length` of -1
+  /// means "all successors of the operation".
+  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
+                               : length,
+                  step),
+        // The base is initialized before members, so moving here is safe.
+        operation(std::move(operation)) {}
+
+  /// Replaces the successor at `index` (normalized through wrapIndex) with
+  /// `block`.
+  void dunderSetItem(intptr_t index, PyBlock block) {
+    index = wrapIndex(index);
+    mlirOperationSetSuccessor(operation->get(), index, block.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
+          nanobind::arg("block"),
+          "Sets the successor block at the specified index.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpSuccessors, PyBlock>;
+
+  /// Returns the number of successors, after checking the operation is valid.
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumSuccessors(operation->get());
+  }
+
+  /// Returns successor `pos` as a PyBlock referencing this list's operation.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock successor = mlirOperationGetSuccessor(operation->get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpSuccessors(operation, startIndex, length, step);
+  }
+
+  /// The operation whose successors are exposed; held to extend its lifetime.
+  PyOperationRef operation;
+};
+
+/// A list of block successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation and block whose successors these are, and thus
+/// extends the lifetime of this operation and block.
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockSuccessors";
+
+  /// Creates a view over the successors of `block`. A `length` of -1 means
+  /// "all successors of the block".
+  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+                    intptr_t startIndex = 0, intptr_t length = -1,
+                    intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumSuccessors(block.get())
+                               : length,
+                  step),
+        // The base is initialized before members, so moving here is safe.
+        operation(std::move(operation)), block(std::move(block)) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+  /// Returns the number of successors, after checking the block is valid.
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumSuccessors(block.get());
+  }
+
+  /// Returns successor `pos` of the block as a PyBlock referencing the
+  /// stored operation. The local is named `successor` so it does not shadow
+  /// the `block` member.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyBlockSuccessors(block, operation, startIndex, length, step);
+  }
+
+  /// Operation associated with the block; held to extend its lifetime.
+  PyOperationRef operation;
+  /// The block whose successors are exposed.
+  PyBlock block;
+};
+
+/// A list of block predecessors. The (returned) predecessor list is
+/// associated with the operation and block whose predecessors these are, and
+/// thus extends the lifetime of this operation and block.
+///
+/// WARNING: This Sliceable is more expensive than the others here because
+/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
+/// operands) anew for each indexed access.
+class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockPredecessors";
+
+  /// Creates a view over the predecessors of `block`. A `length` of -1 means
+  /// "all predecessors of the block".
+  PyBlockPredecessors(PyBlock block, PyOperationRef operation,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumPredecessors(block.get())
+                               : length,
+                  step),
+        // The base is initialized before members, so moving here is safe.
+        operation(std::move(operation)), block(std::move(block)) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockPredecessors, PyBlock>;
+
+  /// Returns the number of predecessors, after checking the block is valid.
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumPredecessors(block.get());
+  }
+
+  /// Returns predecessor `pos` of the block as a PyBlock referencing the
+  /// stored operation. The local is named `predecessor` so it does not shadow
+  /// the `block` member.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+    return PyBlock(operation, predecessor);
+  }
+
+  PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockPredecessors(block, operation, startIndex, length, step);
+  }
+
+  /// Operation associated with the block; held to extend its lifetime.
+  PyOperationRef operation;
+  /// The block whose predecessors are exposed.
+  PyBlock block;
+};
+
+/// A list of operation attributes. Can be indexed by name, producing
+/// attributes, or by index, producing named attributes.
+///
+/// Holds a PyOperationRef to the operation whose attributes it exposes. Every
+/// accessor reads or writes that operation's attributes directly; nothing is
+/// cached in this object.
+class PyOpAttributeMap {
+public:
+  PyOpAttributeMap(PyOperationRef operation)
+      : operation(std::move(operation)) {}
+
+  /// Returns the attribute named `name`, passed through maybeDownCast()
+  /// (presumably to the most specific registered Python attribute class).
+  /// Raises KeyError if the operation has no such attribute.
+  nanobind::typed<nanobind::object, PyAttribute>
+  dunderGetItemNamed(const std::string &name) {
+    MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
+                                                         toMlirStringRef(name));
+    if (mlirAttributeIsNull(attr)) {
+      throw nanobind::key_error("attempt to access a non-existent attribute");
+    }
+    return PyAttribute(operation->getContext(), attr).maybeDownCast();
+  }
+
+  /// Returns the (name, attribute) pair at position `index`. Negative indices
+  /// count from the end, Python-style; raises IndexError when out of range.
+  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    // Query the attribute count once; it serves both wrap-around and the
+    // bounds check.
+    intptr_t numAttributes = dunderLen();
+    if (index < 0) {
+      index += numAttributes;
+    }
+    if (index < 0 || index >= numAttributes) {
+      throw nanobind::index_error("attempt to access out of bounds attribute");
+    }
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), index);
+    // MlirStringRef is a non-owning (data, length) view; resolve it once and
+    // copy it into the std::string that PyNamedAttribute takes.
+    MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+    return PyNamedAttribute(namedAttr.attribute,
+                            std::string(name.data, name.length));
+  }
+
+  /// Sets the attribute `name` on the operation to `attr`.
+  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
+    mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
+                                    attr);
+  }
+
+  /// Removes the attribute `name`; raises KeyError if it does not exist.
+  void dunderDelItem(const std::string &name) {
+    int removed = mlirOperationRemoveAttributeByName(operation->get(),
+                                                     toMlirStringRef(name));
+    if (!removed)
+      throw nanobind::key_error("attempt to delete a non-existent attribute");
+  }
+
+  /// Returns the number of attributes currently attached to the operation.
+  intptr_t dunderLen() {
+    return mlirOperationGetNumAttributes(operation->get());
+  }
+
+  /// Returns true if the operation has an attribute named `name`.
+  bool dunderContains(const std::string &name) {
+    return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
+        operation->get(), toMlirStringRef(name)));
+  }
+
+  /// Invokes `fn(name, attribute)` for every attribute of `op`, in index
+  /// order.
+  static void
+  forEachAttr(MlirOperation op,
+              llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+    intptr_t n = mlirOperationGetNumAttributes(op);
+    for (intptr_t i = 0; i < n; ++i) {
+      MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
+      MlirStringRef name = mlirIdentifierStr(na.name);
+      fn(name, na.attribute);
+    }
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__contains__", &PyOpAttributeMap::dunderContains,
+             nanobind::arg("name"),
+             "Checks if an attribute with the given name exists in the map.")
+        .def("__len__", &PyOpAttributeMap::dunderLen,
+             "Returns the number of attributes in the map.")
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
+             nanobind::arg("name"), "Gets an attribute by name.")
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
+             nanobind::arg("index"), "Gets a named attribute by index.")
+        .def("__setitem__", &PyOpAttributeMap::dunderSetItem,
+             nanobind::arg("name"), nanobind::arg("attr"),
+             "Sets an attribute with the given name.")
+        .def("__delitem__", &PyOpAttributeMap::dunderDelItem,
+             nanobind::arg("name"), "Deletes an attribute with the given name.")
+        .def(
+            "__iter__",
+            [](PyOpAttributeMap &self) {
+              // Names are collected into a list first, so iteration walks a
+              // snapshot taken when iteration starts.
+              nanobind::list keys;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef name, MlirAttribute) {
+                    keys.append(nanobind::str(name.data, name.length));
+                  });
+              return nanobind::iter(keys);
+            },
+            "Iterates over attribute names.")
+        .def(
+            "keys",
+            [](PyOpAttributeMap &self) {
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef name, MlirAttribute) {
+                    out.append(nanobind::str(name.data, name.length));
+                  });
+              return out;
+            },
+            "Returns a list of attribute names.")
+        .def(
+            "values",
+            [](PyOpAttributeMap &self) {
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef, MlirAttribute attr) {
+                    out.append(PyAttribute(self.operation->getContext(), attr)
+                                   .maybeDownCast());
+                  });
+              return out;
+            },
+            "Returns a list of attribute values.")
+        .def(
+            "items",
+            [](PyOpAttributeMap &self) {
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef name, MlirAttribute attr) {
+                    out.append(nanobind::make_tuple(
+                        nanobind::str(name.data, name.length),
+                        PyAttribute(self.operation->getContext(), attr)
+                            .maybeDownCast()));
+                  });
+              return out;
+            },
+            "Returns a list of `(name, attribute)` tuples.");
+  }
+
+private:
+  /// The operation whose attributes are exposed.
+  PyOperationRef operation;
+};
+/// NOTE(review): declared here without documentation; presumably returns the
+/// single result of `operation` — confirm the contract against its definition.
+MlirValue getUniqueResult(MlirOperation operation);
} // namespace python
} // namespace mlir
@@ -1345,4 +2352,4 @@ struct type_caster<mlir::python::DefaultingPyLocation>
} // namespace detail
} // namespace nanobind
-#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
+#endif // MLIR_BINDINGS_PYTHON_IRCORE_H
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
similarity index 100%
rename from mlir/lib/Bindings/Python/NanobindUtils.h
rename to mlir/include/mlir/Bindings/Python/NanobindUtils.h
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 0d1d9e89f92f6..a87918a05b126 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Dialect/SMT.h"
#include "mlir-c/IR.h"
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/Globals.cpp
similarity index 97%
rename from mlir/lib/Bindings/Python/IRModule.cpp
rename to mlir/lib/Bindings/Python/Globals.cpp
index 0de2f1711829b..bc6b210426221 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -6,25 +6,27 @@
//
//===----------------------------------------------------------------------===//
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include <optional>
#include <vector>
-#include "Globals.h"
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/Globals.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
+// clang-format on
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
+namespace mlir::python {
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
@@ -265,3 +267,7 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
}
return isUserTracebackFilenameCache[file];
}
+
+// Out-of-line definition of PyGlobalDebugFlag's static mutex, which the
+// flag's accessors lock around calls into LLVM's global debug flag/type state.
+nanobind::ft_mutex PyGlobalDebugFlag::mutex;
+
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 7147f2cbad149..624d8f0fa57ce 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -13,11 +13,13 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir/Bindings/Python/IRCore.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/IntegerSet.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Support/LLVM.h"
@@ -509,7 +511,8 @@ PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
rawIntegerSet);
}
-void mlir::python::populateIRAffine(nb::module_ &m) {
+namespace mlir::python {
+void populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
//----------------------------------------------------------------------------
@@ -995,3 +998,4 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
PyIntegerSetConstraint::bind(m);
PyIntegerSetConstraintList::bind(m);
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c0a945e3f4f3b..36367e658697c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -12,12 +12,12 @@
#include <string_view>
#include <utility>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
@@ -1799,7 +1799,8 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-void mlir::python::populateIRAttributes(nb::module_ &m) {
+namespace mlir::python {
+void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
@@ -1851,4 +1852,18 @@ void mlir::python::populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
+ // Translate C++ MLIRError exceptions into instances of the Python-side
+ // `MLIRError` class from the `ir` module. Exceptions of any other type are
+ // not caught here and propagate out of the translator unchanged.
+ // NOTE(review): this is unrelated to attributes; confirm that
+ // populateIRAttributes is the intended place to register it.
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through nanobind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..88cffb64906d7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
+// clang-format off
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
@@ -22,6 +24,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
+#include <iostream>
#include <optional>
namespace nb = nanobind;
@@ -33,504 +36,7 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-static const char kModuleParseDocstring[] =
- R"(Parses a module's assembly format from a string.
-
-Returns a new MlirModule or raises an MLIRError if the parsing fails.
-
-See also: https://mlir.llvm.org/docs/LangRef/
-)";
-
-static const char kDumpDocstring[] =
- "Dumps a debug representation of the object to stderr.";
-
-static const char kValueReplaceAllUsesExceptDocstring[] =
- R"(Replace all uses of this value with the `with` value, except for those
-in `exceptions`. `exceptions` can be either a single operation or a list of
-operations.
-)";
-
-//------------------------------------------------------------------------------
-// Utilities.
-//------------------------------------------------------------------------------
-
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-static nb::object classmethod(Func f, Args... args) {
- nb::object cf = nb::cpp_function(f, args...);
- return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
-}
-
-static nb::object
-createCustomDialectWrapper(const std::string &dialectNamespace,
- nb::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
- if (!dialectClass) {
- // Use the base class.
- return nb::cast(PyDialect(std::move(dialectDescriptor)));
- }
-
- // Create the custom implementation.
- return (*dialectClass)(std::move(dialectDescriptor));
-}
-
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(std::string_view s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(const nb::bytes &s) {
- return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
-}
-
-/// Create a block, using the current location context if no locations are
-/// specified.
-static MlirBlock createBlock(const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- SmallVector<MlirType> argTypes;
- argTypes.reserve(nb::len(pyArgTypes));
- for (const auto &pyType : pyArgTypes)
- argTypes.push_back(nb::cast<PyType &>(pyType));
-
- SmallVector<MlirLocation> argLocs;
- if (pyArgLocs) {
- argLocs.reserve(nb::len(*pyArgLocs));
- for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
- } else if (!argTypes.empty()) {
- argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
- }
-
- if (argTypes.size() != argLocs.size())
- throw nb::value_error(("Expected " + Twine(argTypes.size()) +
- " locations, got: " + Twine(argLocs.size()))
- .str()
- .c_str());
- return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
-}
-
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nb::object &o, bool enable) {
- nb::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nb::object &) {
- nb::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- static void bind(nb::module_ &m) {
- // Debug flags.
- nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- "types"_a, "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- "types"_a,
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- static nb::ft_mutex mutex;
-};
-
-nb::ft_mutex PyGlobalDebugFlag::mutex;
-
-struct PyAttrBuilderMap {
- static bool dunderContains(const std::string &attributeKind) {
- return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
- }
- static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
- auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
- if (!builder)
- throw nb::key_error(attributeKind.c_str());
- return *builder;
- }
- static void dunderSetItemNamed(const std::string &attributeKind,
- nb::callable func, bool replace) {
- PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
- .def_static("contains", &PyAttrBuilderMap::dunderContains,
- "attribute_kind"_a,
- "Checks whether an attribute builder is registered for the "
- "given attribute kind.")
- .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
- "attribute_kind"_a,
- "Gets the registered attribute builder for the given "
- "attribute kind.")
- .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
- "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
- "Register an attribute builder for building MLIR "
- "attributes from Python values.");
- }
-};
-
-//------------------------------------------------------------------------------
-// PyBlock
-//------------------------------------------------------------------------------
-
-nb::object PyBlock::getCapsule() {
- return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
-}
-
-//------------------------------------------------------------------------------
-// Collections.
-//------------------------------------------------------------------------------
-
-namespace {
-
-class PyRegionIterator {
-public:
- PyRegionIterator(PyOperationRef operation, int nextIndex)
- : operation(std::move(operation)), nextIndex(nextIndex) {}
-
- PyRegionIterator &dunderIter() { return *this; }
-
- PyRegion dunderNext() {
- operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw nb::stop_iteration();
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
- return PyRegion(operation, region);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyRegionIterator>(m, "RegionIterator")
- .def("__iter__", &PyRegionIterator::dunderIter,
- "Returns an iterator over the regions in the operation.")
- .def("__next__", &PyRegionIterator::dunderNext,
- "Returns the next region in the iteration.");
- }
-
-private:
- PyOperationRef operation;
- intptr_t nextIndex = 0;
-};
-
-/// Regions of an op are fixed length and indexed numerically so are represented
-/// with a sequence-like container.
-class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
-public:
- static constexpr const char *pyClassName = "RegionSequence";
-
- PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumRegions(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- PyRegionIterator dunderIter() {
- operation->checkValid();
- return PyRegionIterator(operation, startIndex);
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__iter__", &PyRegionList::dunderIter,
- "Returns an iterator over the regions in the sequence.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyRegionList, PyRegion>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumRegions(operation->get());
- }
-
- PyRegion getRawElement(intptr_t pos) {
- operation->checkValid();
- return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
- }
-
- PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyRegionList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-class PyBlockIterator {
-public:
- PyBlockIterator(PyOperationRef operation, MlirBlock next)
- : operation(std::move(operation)), next(next) {}
-
- PyBlockIterator &dunderIter() { return *this; }
-
- PyBlock dunderNext() {
- operation->checkValid();
- if (mlirBlockIsNull(next)) {
- throw nb::stop_iteration();
- }
-
- PyBlock returnBlock(operation, next);
- next = mlirBlockGetNextInRegion(next);
- return returnBlock;
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockIterator>(m, "BlockIterator")
- .def("__iter__", &PyBlockIterator::dunderIter,
- "Returns an iterator over the blocks in the operation's region.")
- .def("__next__", &PyBlockIterator::dunderNext,
- "Returns the next block in the iteration.");
- }
-
-private:
- PyOperationRef operation;
- MlirBlock next;
-};
-
-/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
-/// we present them as a more full-featured list-like container but optimize
-/// it for forward iteration. Blocks are always owned by a region.
-class PyBlockList {
-public:
- PyBlockList(PyOperationRef operation, MlirRegion region)
- : operation(std::move(operation)), region(region) {}
-
- PyBlockIterator dunderIter() {
- operation->checkValid();
- return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
- }
-
- intptr_t dunderLen() {
- operation->checkValid();
- intptr_t count = 0;
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- count += 1;
- block = mlirBlockGetNextInRegion(block);
- }
- return count;
- }
-
- PyBlock dunderGetItem(intptr_t index) {
- operation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds block");
- }
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- if (index == 0) {
- return PyBlock(operation, block);
- }
- block = mlirBlockGetNextInRegion(block);
- index -= 1;
- }
- throw nb::index_error("attempt to access out of bounds block");
- }
-
- PyBlock appendBlock(const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- operation->checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- mlirRegionAppendOwnedBlock(region, block);
- return PyBlock(operation, block);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockList>(m, "BlockList")
- .def("__getitem__", &PyBlockList::dunderGetItem,
- "Returns the block at the specified index.")
- .def("__iter__", &PyBlockList::dunderIter,
- "Returns an iterator over blocks in the operation's region.")
- .def("__len__", &PyBlockList::dunderLen,
- "Returns the number of blocks in the operation's region.")
- .def("append", &PyBlockList::appendBlock,
- R"(
- Appends a new block, with argument types as positional args.
-
- Returns:
- The created block.
- )",
- nb::arg("args"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt);
- }
-
-private:
- PyOperationRef operation;
- MlirRegion region;
-};
-
-class PyOperationIterator {
-public:
- PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
- : parentOperation(std::move(parentOperation)), next(next) {}
-
- PyOperationIterator &dunderIter() { return *this; }
-
- nb::typed<nb::object, PyOpView> dunderNext() {
- parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
- throw nb::stop_iteration();
- }
-
- PyOperationRef returnOperation =
- PyOperation::forOperation(parentOperation->getContext(), next);
- next = mlirOperationGetNextInBlock(next);
- return returnOperation->createOpView();
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationIterator>(m, "OperationIterator")
- .def("__iter__", &PyOperationIterator::dunderIter,
- "Returns an iterator over the operations in an operation's block.")
- .def("__next__", &PyOperationIterator::dunderNext,
- "Returns the next operation in the iteration.");
- }
-
-private:
- PyOperationRef parentOperation;
- MlirOperation next;
-};
-
-/// Operations are exposed by the C-API as a forward-only linked list. In
-/// Python, we present them as a more full-featured list-like container but
-/// optimize it for forward iteration. Iterable operations are always owned
-/// by a block.
-class PyOperationList {
-public:
- PyOperationList(PyOperationRef parentOperation, MlirBlock block)
- : parentOperation(std::move(parentOperation)), block(block) {}
-
- PyOperationIterator dunderIter() {
- parentOperation->checkValid();
- return PyOperationIterator(parentOperation,
- mlirBlockGetFirstOperation(block));
- }
-
- intptr_t dunderLen() {
- parentOperation->checkValid();
- intptr_t count = 0;
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- count += 1;
- childOp = mlirOperationGetNextInBlock(childOp);
- }
- return count;
- }
-
- nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
- parentOperation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds operation");
- }
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- if (index == 0) {
- return PyOperation::forOperation(parentOperation->getContext(), childOp)
- ->createOpView();
- }
- childOp = mlirOperationGetNextInBlock(childOp);
- index -= 1;
- }
- throw nb::index_error("attempt to access out of bounds operation");
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationList>(m, "OperationList")
- .def("__getitem__", &PyOperationList::dunderGetItem,
- "Returns the operation at the specified index.")
- .def("__iter__", &PyOperationList::dunderIter,
- "Returns an iterator over operations in the list.")
- .def("__len__", &PyOperationList::dunderLen,
- "Returns the number of operations in the list.");
- }
-
-private:
- PyOperationRef parentOperation;
- MlirBlock block;
-};
-
-class PyOpOperand {
-public:
- PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
-
- nb::typed<nb::object, PyOpView> getOwner() {
- MlirOperation owner = mlirOpOperandGetOwner(opOperand);
- PyMlirContextRef context =
- PyMlirContext::forContext(mlirOperationGetContext(owner));
- return PyOperation::forOperation(context, owner)->createOpView();
- }
-
- size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperand>(m, "OpOperand")
- .def_prop_ro("owner", &PyOpOperand::getOwner,
- "Returns the operation that owns this operand.")
- .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
- "Returns the operand number in the owning operation.");
- }
-
-private:
- MlirOpOperand opOperand;
-};
-
-class PyOpOperandIterator {
-public:
- PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
-
- PyOpOperandIterator &dunderIter() { return *this; }
-
- PyOpOperand dunderNext() {
- if (mlirOpOperandIsNull(opOperand))
- throw nb::stop_iteration();
-
- PyOpOperand returnOpOperand(opOperand);
- opOperand = mlirOpOperandGetNextUse(opOperand);
- return returnOpOperand;
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
- .def("__iter__", &PyOpOperandIterator::dunderIter,
- "Returns an iterator over operands.")
- .def("__next__", &PyOpOperandIterator::dunderNext,
- "Returns the next operand in the iteration.");
- }
-
-private:
- MlirOpOperand opOperand;
-};
-
-} // namespace
-
+namespace mlir::python {
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -1413,8 +919,11 @@ nb::object PyOperation::create(std::string_view name,
  // Construct the operation.
  PyMlirContext::ErrorCapture errors(location.getContext());
  MlirOperation operation = mlirOperationCreate(&state);
+  // On failure, forward everything `errors` captured to MLIRError. take()
+  // hands the captured diagnostics off, so call it exactly once, here;
+  // draining it earlier (e.g. to print) would leave the exception empty.
  if (!operation.ptr)
    throw MLIRError("Operation creation failed", errors.take());
  PyOperationRef created =
      PyOperation::createDetached(location.getContext(), operation);
  maybeInsertOperation(created, maybeIp);
@@ -1448,163 +958,6 @@ void PyOperation::erase() {
  mlirOperationDestroy(operation);
}
-namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- // and redefine bindDerived.
- using ClassTy = nb::class_<DerivedTy, PyValue>;
- using IsAFunctionTy = bool (*)(MlirValue);
-
- PyConcreteValue() = default;
- PyConcreteValue(PyOperationRef operationRef, MlirValue value)
- : PyValue(operationRef, value) {}
- PyConcreteValue(PyValue &orig)
- : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
- /// Attempts to cast the original value to the derived type and throws on
- /// type mismatches.
- static MlirValue castFrom(PyValue &orig) {
- if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
- throw nb::value_error((Twine("Cannot cast value to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str()
- .c_str());
- }
- return orig.get();
- }
-
- /// Binds the Python module objects to functions of this class.
- static void bind(nb::module_ &m) {
- auto cls = ClassTy(
- m, DerivedTy::pyClassName, nb::is_generic(),
- nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
- .str()
- .c_str()));
- cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
- cls.def_static(
- "isinstance",
- [](PyValue &otherValue) -> bool {
- return DerivedTy::isaFunction(otherValue);
- },
- nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
- return self.maybeDownCast();
- });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-} // namespace
-
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
- static constexpr const char *pyClassName = "OpResult";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation that produces this result.");
- c.def_prop_ro(
- "result_number",
- [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- },
- "Returns the position of this result in the operation's result list.");
- }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<nb::typed<nb::object, PyType>>
-getValueTypes(Container &container, PyMlirContextRef &context) {
- std::vector<nb::typed<nb::object, PyType>> result;
- result.reserve(container.size());
- for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(PyType(context->getRef(),
- mlirValueGetType(container.getElement(i).get()))
- .maybeDownCast());
- }
- return result;
-}
-
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
- static constexpr const char *pyClassName = "OpResultList";
- using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
- PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all results in this result list.");
- c.def_prop_ro(
- "owner",
- [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
- return self.operation->createOpView();
- },
- "Returns the operation that owns this result list.");
- }
-
- PyOperationRef &getOperation() { return operation; }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpResultList, PyOpResult>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
-
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
-
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------
@@ -1706,7 +1059,7 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
  }
}
-static MlirValue getUniqueResult(MlirOperation operation) {
+MlirValue getUniqueResult(MlirOperation operation) {
  auto numResults = mlirOperationGetNumResults(operation);
  if (numResults != 1) {
    auto name = mlirIdentifierStr(mlirOperationGetName(operation));
@@ -2319,2648 +1672,11 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-namespace {
-
-/// Python wrapper for MlirBlockArgument.
-class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
- static constexpr const char *pyClassName = "BlockArgument";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- },
- "Returns the block that owns this argument.");
- c.def_prop_ro(
- "arg_number",
- [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- },
- "Returns the position of this argument in the block's argument list.");
- c.def(
- "set_type",
- [](PyBlockArgument &self, PyType type) {
- return mlirBlockArgumentSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of this block argument.");
- c.def(
- "set_location",
- [](PyBlockArgument &self, PyLocation loc) {
- return mlirBlockArgumentSetLocation(self.get(), loc);
- },
- nb::arg("loc"), "Sets the location of this block argument.");
- }
-};
-
-/// A list of block arguments. Internally, these are stored as consecutive
-/// elements, random access is cheap. The argument list is associated with the
-/// operation that contains the block (detached blocks are not allowed in
-/// Python bindings) and extends its lifetime.
-class PyBlockArgumentList
- : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
-public:
- static constexpr const char *pyClassName = "BlockArgumentList";
- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
-
- PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumArguments(block) : length,
- step),
- operation(std::move(operation)), block(block) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all arguments in this argument list.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
-
- /// Returns the number of arguments in the list.
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirBlockGetNumArguments(block);
- }
-
- /// Returns `pos`-the element in the list.
- PyBlockArgument getRawElement(intptr_t pos) {
- MlirValue argument = mlirBlockGetArgument(block, pos);
- return PyBlockArgument(operation, argument);
- }
-
- /// Returns a sublist of this list.
- PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockArgumentList(operation, block, startIndex, length, step);
- }
-
- PyOperationRef operation;
- MlirBlock block;
-};
-
-/// A list of operation operands. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) operand list is associated
-/// with the operation whose operands these are, and thus extends the lifetime
-/// of this operation.
-class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
-public:
- static constexpr const char *pyClassName = "OpOperandList";
- using SliceableT = Sliceable<PyOpOperandList, PyValue>;
-
- PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumOperands(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyValue value) {
- index = wrapIndex(index);
- mlirOperationSetOperand(operation->get(), index, value.get());
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"),
- nb::arg("value"),
- "Sets the operand at the specified index to a new value.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpOperandList, PyValue>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumOperands(operation->get());
- }
-
- PyValue getRawElement(intptr_t pos) {
- MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
- return PyValue(pyOwner, operand);
- }
-
- PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpOperandList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-/// A list of operation successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation whose successors these are, and thus extends
-/// the lifetime of this operation.
-class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "OpSuccessors";
-
- PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumSuccessors(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyBlock block) {
- index = wrapIndex(index);
- mlirOperationSetSuccessor(operation->get(), index, block.get());
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"),
- nb::arg("block"), "Sets the successor block at the specified index.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpSuccessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumSuccessors(operation->get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
- return PyBlock(operation, block);
- }
-
- PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpSuccessors(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-/// A list of block successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation and block whose successors these are, and thus
-/// extends the lifetime of this operation and block.
-class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockSuccessors";
-
- PyBlockSuccessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumSuccessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockSuccessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumSuccessors(block.get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
-
- PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyBlockSuccessors(block, operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of block predecessors. The (returned) predecessor list is
-/// associated with the operation and block whose predecessors these are, and
-/// thus extends the lifetime of this operation and block.
-///
-/// WARNING: This Sliceable is more expensive than the others here because
-/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
-/// operands) anew for each indexed access.
-class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockPredecessors";
-
- PyBlockPredecessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumPredecessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockPredecessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumPredecessors(block.get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
-
- PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockPredecessors(block, operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of operation attributes. Can be indexed by name, producing
-/// attributes, or by index, producing named attributes.
-class PyOpAttributeMap {
-public:
- PyOpAttributeMap(PyOperationRef operation)
- : operation(std::move(operation)) {}
-
- nb::typed<nb::object, PyAttribute>
- dunderGetItemNamed(const std::string &name) {
- MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
- throw nb::key_error("attempt to access a non-existent attribute");
- }
- return PyAttribute(operation->getContext(), attr).maybeDownCast();
- }
-
- PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0 || index >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr =
- mlirOperationGetAttribute(operation->get(), index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data,
- mlirIdentifierStr(namedAttr.name).length));
- }
-
- void dunderSetItem(const std::string &name, const PyAttribute &attr) {
- mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
- attr);
- }
-
- void dunderDelItem(const std::string &name) {
- int removed = mlirOperationRemoveAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (!removed)
- throw nb::key_error("attempt to delete a non-existent attribute");
- }
-
- intptr_t dunderLen() {
- return mlirOperationGetNumAttributes(operation->get());
- }
-
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
- operation->get(), toMlirStringRef(name)));
- }
-
- static void
- forEachAttr(MlirOperation op,
- llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
- intptr_t n = mlirOperationGetNumAttributes(op);
- for (intptr_t i = 0; i < n; ++i) {
- MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
- MlirStringRef name = mlirIdentifierStr(na.name);
- fn(name, na.attribute);
- }
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
- .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"),
- "Checks if an attribute with the given name exists in the map.")
- .def("__len__", &PyOpAttributeMap::dunderLen,
- "Returns the number of attributes in the map.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
- nb::arg("name"), "Gets an attribute by name.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
- nb::arg("index"), "Gets a named attribute by index.")
- .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"),
- nb::arg("attr"), "Sets an attribute with the given name.")
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
- "Deletes an attribute with the given name.")
- .def(
- "__iter__",
- [](PyOpAttributeMap &self) {
- nb::list keys;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- keys.append(nb::str(name.data, name.length));
- });
- return nb::iter(keys);
- },
- "Iterates over attribute names.")
- .def(
- "keys",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- out.append(nb::str(name.data, name.length));
- });
- return out;
- },
- "Returns a list of attribute names.")
- .def(
- "values",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef, MlirAttribute attr) {
- out.append(PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast());
- });
- return out;
- },
- "Returns a list of attribute values.")
- .def(
- "items",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute attr) {
- out.append(nb::make_tuple(
- nb::str(name.data, name.length),
- PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast()));
- });
- return out;
- },
- "Returns a list of `(name, attribute)` tuples.");
- }
-
-private:
- PyOperationRef operation;
-};
-
-// see
-// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
-
-#ifndef _Py_CAST
-#define _Py_CAST(type, expr) ((type)(expr))
-#endif
-
-// Static inline functions should use _Py_NULL rather than using directly NULL
-// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
-// _Py_NULL is defined as nullptr.
-#ifndef _Py_NULL
-#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
- (defined(__cplusplus) && __cplusplus >= 201103)
-#define _Py_NULL nullptr
-#else
-#define _Py_NULL NULL
-#endif
-#endif
-
-// Python 3.10.0a3
-#if PY_VERSION_HEX < 0x030A00A3
-
-// bpo-42262 added Py_XNewRef()
-#if !defined(Py_XNewRef)
-[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
- Py_XINCREF(obj);
- return obj;
-}
-#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
-#endif
-
-// bpo-42262 added Py_NewRef()
-#if !defined(Py_NewRef)
-[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
- Py_INCREF(obj);
- return obj;
-}
-#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
-#endif
-
-#endif // Python 3.10.0a3
-
-// Python 3.9.0b1
-#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
-
-// bpo-40429 added PyThreadState_GetFrame()
-PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
- assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
-}
-
-// bpo-40421 added PyFrame_GetBack()
-PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
-}
-
-// bpo-40421 added PyFrame_GetCode()
-PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
- return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
-}
-
-#endif // Python 3.9.0b1
-
-MlirLocation tracebackToLocation(MlirContext ctx) {
- size_t framesLimit =
- PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
- // Use a thread_local here to avoid requiring a large amount of space.
- thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
- frames;
- size_t count = 0;
-
- nb::gil_scoped_acquire acquire;
- PyThreadState *tstate = PyThreadState_GET();
- PyFrameObject *next;
- PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
- // In the increment expression:
- // 1. get the next prev frame;
- // 2. decrement the ref count on the current frame (in order that it can get
- // gc'd, along with any objects in its closure and etc);
- // 3. set current = next.
- for (; pyFrame != nullptr && count < framesLimit;
- next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
- PyCodeObject *code = PyFrame_GetCode(pyFrame);
- auto fileNameStr =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
- llvm::StringRef fileName(fileNameStr);
- if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
- continue;
-
- // co_qualname and PyCode_Addr2Location added in py3.11
-#if PY_VERSION_HEX < 0x030B00F0
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
- llvm::StringRef funcName(name);
- int startLine = PyFrame_GetLineNumber(pyFrame);
- MlirLocation loc =
- mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
-#else
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
- llvm::StringRef funcName(name);
- int startLine, startCol, endLine, endCol;
- int lasti = PyFrame_GetLasti(pyFrame);
- if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
- &endCol)) {
- throw nb::python_error();
- }
- MlirLocation loc = mlirLocationFileLineColRangeGet(
- ctx, wrap(fileName), startLine, startCol, endLine, endCol);
-#endif
-
- frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
- ++count;
- }
- // When the loop breaks (after the last iter), current frame (if non-null)
- // is leaked without this.
- Py_XDECREF(pyFrame);
-
- if (count == 0)
- return mlirLocationUnknownGet(ctx);
-
- MlirLocation callee = frames[0];
- assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
- if (count == 1)
- return callee;
-
- MlirLocation caller = frames[count - 1];
- assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
- for (int i = count - 2; i >= 1; i--)
- caller = mlirLocationCallSiteGet(frames[i], caller);
-
- return mlirLocationCallSiteGet(callee, caller);
-}
-
-PyLocation
-maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
- if (location.has_value())
- return location.value();
- if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
- return DefaultingPyLocation::resolve();
-
- PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
- MlirLocation mlirLoc = tracebackToLocation(ctx.get());
- PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
- return {ref, mlirLoc};
-}
-
-} // namespace
-
-//------------------------------------------------------------------------------
-// Populates the core exports of the 'ir' submodule.
-//------------------------------------------------------------------------------
-
-void mlir::python::populateIRCore(nb::module_ &m) {
- // disable leak warnings which tend to be false positives.
- nb::set_leak_warnings(false);
- //----------------------------------------------------------------------------
- // Enums.
- //----------------------------------------------------------------------------
- nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
- .value("ERROR", MlirDiagnosticError)
- .value("WARNING", MlirDiagnosticWarning)
- .value("NOTE", MlirDiagnosticNote)
- .value("REMARK", MlirDiagnosticRemark);
-
- nb::enum_<MlirWalkOrder>(m, "WalkOrder")
- .value("PRE_ORDER", MlirWalkPreOrder)
- .value("POST_ORDER", MlirWalkPostOrder);
-
- nb::enum_<MlirWalkResult>(m, "WalkResult")
- .value("ADVANCE", MlirWalkResultAdvance)
- .value("INTERRUPT", MlirWalkResultInterrupt)
- .value("SKIP", MlirWalkResultSkip);
-
- //----------------------------------------------------------------------------
- // Mapping of Diagnostics.
- //----------------------------------------------------------------------------
- nb::class_<PyDiagnostic>(m, "Diagnostic")
- .def_prop_ro("severity", &PyDiagnostic::getSeverity,
- "Returns the severity of the diagnostic.")
- .def_prop_ro("location", &PyDiagnostic::getLocation,
- "Returns the location associated with the diagnostic.")
- .def_prop_ro("message", &PyDiagnostic::getMessage,
- "Returns the message text of the diagnostic.")
- .def_prop_ro("notes", &PyDiagnostic::getNotes,
- "Returns a tuple of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic &self) -> nb::str {
- if (!self.isValid())
- return nb::str("<Invalid Diagnostic>");
- return self.getMessage();
- },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
- .def(
- "__init__",
- [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
- new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
- },
- "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
- .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
- "The severity level of the diagnostic.")
- .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
- "The location associated with the diagnostic.")
- .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
- "The message text of the diagnostic.")
- .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
- "List of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
- .def("detach", &PyDiagnosticHandler::detach,
- "Detaches the diagnostic handler from the context.")
- .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
- "Returns True if the handler is attached to a context.")
- .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
- "Returns True if an error was encountered during diagnostic "
- "handling.")
- .def("__enter__", &PyDiagnosticHandler::contextEnter,
- "Enters the diagnostic handler as a context manager.")
- .def("__exit__", &PyDiagnosticHandler::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the diagnostic handler context manager.");
-
- // Expose DefaultThreadPool to python
- nb::class_<PyThreadPool>(m, "ThreadPool")
- .def(
- "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
- "Creates a new thread pool with default concurrency.")
- .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
- "Returns the maximum number of threads in the pool.")
- .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
- "Returns the raw pointer to the LLVM thread pool as a string.");
-
- nb::class_<PyMlirContext>(m, "Context")
- .def(
- "__init__",
- [](PyMlirContext &self) {
- MlirContext context = mlirContextCreateWithThreading(false);
- new (&self) PyMlirContext(context);
- },
- R"(
- Creates a new MLIR context.
-
- The context is the top-level container for all MLIR objects. It owns the storage
- for types, attributes, locations, and other core IR objects. A context can be
- configured to allow or disallow unregistered dialects and can have dialects
- loaded on-demand.)")
- .def_static("_get_live_count", &PyMlirContext::getLiveCount,
- "Gets the number of live Context objects.")
- .def(
- "_get_context_again",
- [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
- PyMlirContextRef ref = PyMlirContext::forContext(self.get());
- return ref.releaseObject();
- },
- "Gets another reference to the same context.")
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
- "Gets the number of live modules owned by this context.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
- "Gets a capsule wrapping the MlirContext.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyMlirContext::createFromCapsule,
- "Creates a Context from a capsule wrapping MlirContext.")
- .def("__enter__", &PyMlirContext::contextEnter,
- "Enters the context as a context manager.")
- .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/)
- -> std::optional<nb::typed<nb::object, PyMlirContext>> {
- auto *context = PyThreadContextEntry::getDefaultContext();
- if (!context)
- return {};
- return nb::cast(context);
- },
- nb::sig("def current(/) -> Context | None"),
- "Gets the Context bound to the current thread or returns None if no "
- "context is set.")
- .def_prop_ro(
- "dialects",
- [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Gets a container for accessing dialects by name.")
- .def_prop_ro(
- "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Alias for `dialects`.")
- .def(
- "get_dialect_descriptor",
- [=](PyMlirContext &self, std::string &name) {
- MlirDialect dialect = mlirContextGetOrLoadDialect(
- self.get(), {name.data(), name.size()});
- if (mlirDialectIsNull(dialect)) {
- throw nb::value_error(
- (Twine("Dialect '") + name + "' not found").str().c_str());
- }
- return PyDialectDescriptor(self.getRef(), dialect);
- },
- nb::arg("dialect_name"),
- "Gets or loads a dialect by name, returning its descriptor object.")
- .def_prop_rw(
- "allow_unregistered_dialects",
- [](PyMlirContext &self) -> bool {
- return mlirContextGetAllowUnregisteredDialects(self.get());
- },
- [](PyMlirContext &self, bool value) {
- mlirContextSetAllowUnregisteredDialects(self.get(), value);
- },
- "Controls whether unregistered dialects are allowed in this context.")
- .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
- nb::arg("callback"),
- "Attaches a diagnostic handler that will receive callbacks.")
- .def(
- "enable_multithreading",
- [](PyMlirContext &self, bool enable) {
- mlirContextEnableMultithreading(self.get(), enable);
- },
- nb::arg("enable"),
- R"(
- Enables or disables multi-threading support in the context.
-
- Args:
- enable: Whether to enable (True) or disable (False) multi-threading.
- )")
- .def(
- "set_thread_pool",
- [](PyMlirContext &self, PyThreadPool &pool) {
- // we should disable multi-threading first before setting
- // new thread pool otherwise the assert in
- // MLIRContext::setThreadPool will be raised.
- mlirContextEnableMultithreading(self.get(), false);
- mlirContextSetThreadPool(self.get(), pool.get());
- },
- R"(
- Sets a custom thread pool for the context to use.
-
- Args:
- pool: A ThreadPool object to use for parallel operations.
-
- Note:
- Multi-threading is automatically disabled before setting the thread pool.)")
- .def(
- "get_num_threads",
- [](PyMlirContext &self) {
- return mlirContextGetNumThreads(self.get());
- },
- "Gets the number of threads in the context's thread pool.")
- .def(
- "_mlir_thread_pool_ptr",
- [](PyMlirContext &self) {
- MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
- std::stringstream ss;
- ss << pool.ptr;
- return ss.str();
- },
- "Gets the raw pointer to the LLVM thread pool as a string.")
- .def(
- "is_registered_operation",
- [](PyMlirContext &self, std::string &name) {
- return mlirContextIsRegisteredOperation(
- self.get(), MlirStringRef{name.data(), name.size()});
- },
- nb::arg("operation_name"),
- R"(
- Checks whether an operation with the given name is registered.
-
- Args:
- operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
-
- Returns:
- True if the operation is registered, False otherwise.)")
- .def(
- "append_dialect_registry",
- [](PyMlirContext &self, PyDialectRegistry ®istry) {
- mlirContextAppendDialectRegistry(self.get(), registry);
- },
- nb::arg("registry"),
- R"(
- Appends the contents of a dialect registry to the context.
-
- Args:
- registry: A DialectRegistry containing dialects to append.)")
- .def_prop_rw("emit_error_diagnostics",
- &PyMlirContext::getEmitErrorDiagnostics,
- &PyMlirContext::setEmitErrorDiagnostics,
- R"(
- Controls whether error diagnostics are emitted to diagnostic handlers.
-
- By default, error diagnostics are captured and reported through MLIRError exceptions.)")
- .def(
- "load_all_available_dialects",
- [](PyMlirContext &self) {
- mlirContextLoadAllAvailableDialects(self.get());
- },
- R"(
- Loads all dialects available in the registry into the context.
-
- This eagerly loads all dialects that have been registered, making them
- immediately available for use.)");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectDescriptor
- //----------------------------------------------------------------------------
- nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
- .def_prop_ro(
- "namespace",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- return nb::str(ns.data, ns.length);
- },
- "Returns the namespace of the dialect.")
- .def(
- "__repr__",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- std::string repr("<DialectDescriptor ");
- repr.append(ns.data, ns.length);
- repr.append(">");
- return repr;
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect descriptor.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialects
- //----------------------------------------------------------------------------
- nb::class_<PyDialects>(m, "Dialects")
- .def(
- "__getitem__",
- [=](PyDialects &self, std::string keyName) {
- MlirDialect dialect =
- self.getDialectForKey(keyName, /*attrError=*/false);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(keyName, std::move(descriptor));
- },
- "Gets a dialect by name using subscript notation.")
- .def(
- "__getattr__",
- [=](PyDialects &self, std::string attrName) {
- MlirDialect dialect =
- self.getDialectForKey(attrName, /*attrError=*/true);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(attrName, std::move(descriptor));
- },
- "Gets a dialect by name using attribute notation.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialect
- //----------------------------------------------------------------------------
- nb::class_<PyDialect>(m, "Dialect")
- .def(nb::init<nb::object>(), nb::arg("descriptor"),
- "Creates a Dialect from a DialectDescriptor.")
- .def_prop_ro(
- "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
- "Returns the DialectDescriptor for this dialect.")
- .def(
- "__repr__",
- [](const nb::object &self) {
- auto clazz = self.attr("__class__");
- return nb::str("<Dialect ") +
- self.attr("descriptor").attr("namespace") +
- nb::str(" (class ") + clazz.attr("__module__") +
- nb::str(".") + clazz.attr("__name__") + nb::str(")>");
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectRegistry
- //----------------------------------------------------------------------------
- nb::class_<PyDialectRegistry>(m, "DialectRegistry")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
- "Gets a capsule wrapping the MlirDialectRegistry.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyDialectRegistry::createFromCapsule,
- "Creates a DialectRegistry from a capsule wrapping "
- "`MlirDialectRegistry`.")
- .def(nb::init<>(), "Creates a new empty dialect registry.");
-
- //----------------------------------------------------------------------------
- // Mapping of Location
- //----------------------------------------------------------------------------
- nb::class_<PyLocation>(m, "Location")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
- "Gets a capsule wrapping the MlirLocation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
- "Creates a Location from a capsule wrapping MlirLocation.")
- .def("__enter__", &PyLocation::contextEnter,
- "Enters the location as a context manager.")
- .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the location context manager.")
- .def(
- "__eq__",
- [](PyLocation &self, PyLocation &other) -> bool {
- return mlirLocationEqual(self, other);
- },
- "Compares two locations for equality.")
- .def(
- "__eq__", [](PyLocation &self, nb::object other) { return false; },
- "Compares location with non-location object (always returns False).")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) -> std::optional<PyLocation *> {
- auto *loc = PyThreadContextEntry::getDefaultLocation();
- if (!loc)
- return std::nullopt;
- return loc;
- },
- // clang-format off
- nb::sig("def current(/) -> Location | None"),
- // clang-format on
- "Gets the Location bound to the current thread or raises ValueError.")
- .def_static(
- "unknown",
- [](DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationUnknownGet(context->get()));
- },
- nb::arg("context") = nb::none(),
- "Gets a Location representing an unknown location.")
- .def_static(
- "callsite",
- [](PyLocation callee, const std::vector<PyLocation> &frames,
- DefaultingPyMlirContext context) {
- if (frames.empty())
- throw nb::value_error("No caller frames provided.");
- MlirLocation caller = frames.back().get();
- for (const PyLocation &frame :
- llvm::reverse(llvm::ArrayRef(frames).drop_back()))
- caller = mlirLocationCallSiteGet(frame.get(), caller);
- return PyLocation(context->getRef(),
- mlirLocationCallSiteGet(callee.get(), caller));
- },
- nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
- "Gets a Location representing a caller and callsite.")
- .def("is_a_callsite", mlirLocationIsACallSite,
- "Returns True if this location is a CallSiteLoc.")
- .def_prop_ro(
- "callee",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCallee(self));
- },
- "Gets the callee location from a CallSiteLoc.")
- .def_prop_ro(
- "caller",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCaller(self));
- },
- "Gets the caller location from a CallSiteLoc.")
- .def_static(
- "file",
- [](std::string filename, int line, int col,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationFileLineColGet(
- context->get(), toMlirStringRef(filename), line, col));
- },
- nb::arg("filename"), nb::arg("line"), nb::arg("col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column.")
- .def_static(
- "file",
- [](std::string filename, int startLine, int startCol, int endLine,
- int endCol, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFileLineColRangeGet(
- context->get(), toMlirStringRef(filename),
- startLine, startCol, endLine, endCol));
- },
- nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
- nb::arg("end_line"), nb::arg("end_col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column range.")
- .def("is_a_file", mlirLocationIsAFileLineColRange,
- "Returns True if this location is a FileLineColLoc.")
- .def_prop_ro(
- "filename",
- [](MlirLocation loc) {
- return mlirIdentifierStr(
- mlirLocationFileLineColRangeGetFilename(loc));
- },
- "Gets the filename from a FileLineColLoc.")
- .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
- "Gets the start line number from a `FileLineColLoc`.")
- .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
- "Gets the start column number from a `FileLineColLoc`.")
- .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
- "Gets the end line number from a `FileLineColLoc`.")
- .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
- "Gets the end column number from a `FileLineColLoc`.")
- .def_static(
- "fused",
- [](const std::vector<PyLocation> &pyLocations,
- std::optional<PyAttribute> metadata,
- DefaultingPyMlirContext context) {
- llvm::SmallVector<MlirLocation, 4> locations;
- locations.reserve(pyLocations.size());
- for (auto &pyLocation : pyLocations)
- locations.push_back(pyLocation.get());
- MlirLocation location = mlirLocationFusedGet(
- context->get(), locations.size(), locations.data(),
- metadata ? metadata->get() : MlirAttribute{0});
- return PyLocation(context->getRef(), location);
- },
- nb::arg("locations"), nb::arg("metadata") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a fused location with optional "
- "metadata.")
- .def("is_a_fused", mlirLocationIsAFused,
- "Returns True if this location is a `FusedLoc`.")
- .def_prop_ro(
- "locations",
- [](PyLocation &self) {
- unsigned numLocations = mlirLocationFusedGetNumLocations(self);
- std::vector<MlirLocation> locations(numLocations);
- if (numLocations)
- mlirLocationFusedGetLocations(self, locations.data());
- std::vector<PyLocation> pyLocations{};
- pyLocations.reserve(numLocations);
- for (unsigned i = 0; i < numLocations; ++i)
- pyLocations.emplace_back(self.getContext(), locations[i]);
- return pyLocations;
- },
- "Gets the list of locations from a `FusedLoc`.")
- .def_static(
- "name",
- [](std::string name, std::optional<PyLocation> childLoc,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationNameGet(
- context->get(), toMlirStringRef(name),
- childLoc ? childLoc->get()
- : mlirLocationUnknownGet(context->get())));
- },
- nb::arg("name"), nb::arg("childLoc") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a named location with optional child "
- "location.")
- .def("is_a_name", mlirLocationIsAName,
- "Returns True if this location is a `NameLoc`.")
- .def_prop_ro(
- "name_str",
- [](MlirLocation loc) {
- return mlirIdentifierStr(mlirLocationNameGetName(loc));
- },
- "Gets the name string from a `NameLoc`.")
- .def_prop_ro(
- "child_loc",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationNameGetChildLoc(self));
- },
- "Gets the child location from a `NameLoc`.")
- .def_static(
- "from_attr",
- [](PyAttribute &attribute, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFromAttribute(attribute));
- },
- nb::arg("attribute"), nb::arg("context") = nb::none(),
- "Gets a Location from a `LocationAttr`.")
- .def_prop_ro(
- "context",
- [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Location`.")
- .def_prop_ro(
- "attr",
- [](PyLocation &self) {
- return PyAttribute(self.getContext(),
- mlirLocationGetAttribute(self));
- },
- "Get the underlying `LocationAttr`.")
- .def(
- "emit_error",
- [](PyLocation &self, std::string message) {
- mlirEmitError(self, message.c_str());
- },
- nb::arg("message"),
- R"(
- Emits an error diagnostic at this location.
-
- Args:
- message: The error message to emit.)")
- .def(
- "__repr__",
- [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly representation of the location.");
-
- //----------------------------------------------------------------------------
- // Mapping of Module
- //----------------------------------------------------------------------------
- nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
- "Gets a capsule wrapping the MlirModule.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
- R"(
- Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
-
- This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
- prevent double-frees (of the underlying `mlir::Module`).)")
- .def("_clear_mlir_module", &PyModule::clearMlirModule,
- R"(
- Clears the internal MLIR module reference.
-
- This is used internally to prevent double-free when ownership is transferred
- via the C API capsule mechanism. Not intended for normal use.)")
- .def_static(
- "parse",
- [](const std::string &moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parse",
- [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parseFile",
- [](const std::string &path, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParseFromFile(
- context->get(), toMlirStringRef(path));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("path"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "create",
- [](const std::optional<PyLocation> &loc)
- -> nb::typed<nb::object, PyModule> {
- PyLocation pyLoc = maybeGetTracebackLocation(loc);
- MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("loc") = nb::none(), "Creates an empty module.")
- .def_prop_ro(
- "context",
- [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that created the `Module`.")
- .def_prop_ro(
- "operation",
- [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
- return PyOperation::forOperation(self.getContext(),
- mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject())
- .releaseObject();
- },
- "Accesses the module as an operation.")
- .def_prop_ro(
- "body",
- [](PyModule &self) {
- PyOperationRef moduleOp = PyOperation::forOperation(
- self.getContext(), mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject());
- PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
- return returnBlock;
- },
- "Return the block for this module.")
- .def(
- "dump",
- [](PyModule &self) {
- mlirOperationDump(mlirModuleGetOperation(self.get()));
- },
- kDumpDocstring)
- .def(
- "__str__",
- [](const nb::object &self) {
- // Defer to the operation's __str__.
- return self.attr("operation").attr("__str__")();
- },
- nb::sig("def __str__(self) -> str"),
- R"(
- Gets the assembly form of the operation with default options.
-
- If more advanced control over the assembly formatting or I/O options is needed,
- use the dedicated print or get_asm method, which supports keyword arguments to
- customize behavior.
- )")
- .def(
- "__eq__",
- [](PyModule &self, PyModule &other) {
- return mlirModuleEqual(self.get(), other.get());
- },
- "other"_a, "Compares two modules for equality.")
- .def(
- "__hash__",
- [](PyModule &self) { return mlirModuleHashValue(self.get()); },
- "Returns the hash value of the module.");
-
- //----------------------------------------------------------------------------
- // Mapping of Operation.
- //----------------------------------------------------------------------------
- nb::class_<PyOperationBase>(m, "_OperationBase")
- .def_prop_ro(
- MLIR_PYTHON_CAPI_PTR_ATTR,
- [](PyOperationBase &self) {
- return self.getOperation().getCapsule();
- },
- "Gets a capsule wrapping the `MlirOperation`.")
- .def(
- "__eq__",
- [](PyOperationBase &self, PyOperationBase &other) {
- return mlirOperationEqual(self.getOperation().get(),
- other.getOperation().get());
- },
- "Compares two operations for equality.")
- .def(
- "__eq__",
- [](PyOperationBase &self, nb::object other) { return false; },
- "Compares operation with non-operation object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyOperationBase &self) {
- return mlirOperationHashValue(self.getOperation().get());
- },
- "Returns the hash value of the operation.")
- .def_prop_ro(
- "attributes",
- [](PyOperationBase &self) {
- return PyOpAttributeMap(self.getOperation().getRef());
- },
- "Returns a dictionary-like map of operation attributes.")
- .def_prop_ro(
- "context",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
- PyOperation &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- return concreteOperation.getContext().getObject();
- },
- "Context that owns the operation.")
- .def_prop_ro(
- "name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation = concreteOperation.get();
- return mlirIdentifierStr(mlirOperationGetName(operation));
- },
- "Returns the fully qualified name of the operation.")
- .def_prop_ro(
- "operands",
- [](PyOperationBase &self) {
- return PyOpOperandList(self.getOperation().getRef());
- },
- "Returns the list of operation operands.")
- .def_prop_ro(
- "regions",
- [](PyOperationBase &self) {
- return PyRegionList(self.getOperation().getRef());
- },
- "Returns the list of operation regions.")
- .def_prop_ro(
- "results",
- [](PyOperationBase &self) {
- return PyOpResultList(self.getOperation().getRef());
- },
- "Returns the list of Operation results.")
- .def_prop_ro(
- "result",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
- auto &operation = self.getOperation();
- return PyOpResult(operation.getRef(), getUniqueResult(operation))
- .maybeDownCast();
- },
- "Shortcut to get an op result if it has only one (throws an error "
- "otherwise).")
- .def_prop_rw(
- "location",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- return PyLocation(operation.getContext(),
- mlirOperationGetLocation(operation.get()));
- },
- [](PyOperationBase &self, const PyLocation &location) {
- PyOperation &operation = self.getOperation();
- mlirOperationSetLocation(operation.get(), location.get());
- },
- nb::for_getter("Returns the source location the operation was "
- "defined or derived from."),
- nb::for_setter("Sets the source location the operation was defined "
- "or derived from."))
- .def_prop_ro(
- "parent",
- [](PyOperationBase &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto parent = self.getOperation().getParentOperation();
- if (parent)
- return parent->getObject();
- return {};
- },
- "Returns the parent operation, or `None` if at top level.")
- .def(
- "__str__",
- [](PyOperationBase &self) {
- return self.getAsm(/*binary=*/false,
- /*largeElementsLimit=*/std::nullopt,
- /*largeResourceLimit=*/std::nullopt,
- /*enableDebugInfo=*/false,
- /*prettyDebugInfo=*/false,
- /*printGenericOpForm=*/false,
- /*useLocalScope=*/false,
- /*useNameLocAsPrefix=*/false,
- /*assumeVerified=*/false,
- /*skipRegions=*/false);
- },
- nb::sig("def __str__(self) -> str"),
- "Returns the assembly form of the operation.")
- .def("print",
- nb::overload_cast<PyAsmState &, nb::object, bool>(
- &PyOperationBase::print),
- nb::arg("state"), nb::arg("file") = nb::none(),
- nb::arg("binary") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- state: `AsmState` capturing the operation numbering and flags.
- file: Optional file like object to write to. Defaults to sys.stdout.
- binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
- .def("print",
- nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
- bool, bool, bool, bool, bool, bool, nb::object,
- bool, bool>(&PyOperationBase::print),
- // Careful: Lots of arguments must match up with print method.
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
- nb::arg("binary") = false, nb::arg("skip_regions") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- large_elements_limit: Whether to elide elements attributes above this
- number of elements. Defaults to None (no limit).
- large_resource_limit: Whether to elide resource attributes above this
- number of characters. Defaults to None (no limit). If large_elements_limit
- is set and this is None, the behavior will be to use large_elements_limit
- as large_resource_limit.
- enable_debug_info: Whether to print debug/location information. Defaults
- to False.
- pretty_debug_info: Whether to format debug information for easier reading
- by a human (warning: the result is unparseable). Defaults to False.
- print_generic_op_form: Whether to print the generic assembly forms of all
- ops. Defaults to False.
- use_local_scope: Whether to print in a way that is more optimized for
- multi-threaded access but may not be consistent with how the overall
- module prints.
- use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
- prefixes for the SSA identifiers. Defaults to False.
- assume_verified: By default, if not printing generic form, the verifier
- will be run and if it fails, generic form will be printed with a comment
- about failed verification. While a reasonable default for interactive use,
- for systematic use, it is often better for the caller to verify explicitly
- and report failures in a more robust fashion. Set this to True if doing this
- in order to avoid running a redundant verification. If the IR is actually
- invalid, behavior is undefined.
- file: The file like object to write to. Defaults to sys.stdout.
- binary: Whether to write bytes (True) or str (False). Defaults to False.
- skip_regions: Whether to skip printing regions. Defaults to False.)")
- .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
- nb::arg("desired_version") = nb::none(),
- R"(
- Write the bytecode form of the operation to a file like object.
-
- Args:
- file: The file like object to write to.
- desired_version: Optional version of bytecode to emit.
- Returns:
- The bytecode writer status.)")
- .def("get_asm", &PyOperationBase::getAsm,
- // Careful: Lots of arguments must match up with get_asm method.
- nb::arg("binary") = false,
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
- R"(
- Gets the assembly form of the operation with all options available.
-
- Args:
- binary: Whether to return a bytes (True) or str (False) object. Defaults to
- False.
- ... others ...: See the print() method for common keyword arguments for
- configuring the printout.
- Returns:
- Either a bytes or str object, depending on the setting of the `binary`
- argument.)")
- .def("verify", &PyOperationBase::verify,
- "Verify the operation. Raises MLIRError if verification fails, and "
- "returns true otherwise.")
- .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
- "Puts self immediately after the other operation in its parent "
- "block.")
- .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
- "Puts self immediately before the other operation in its parent "
- "block.")
- .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
- nb::arg("other"),
- R"(
- Checks if this operation is before another in the same block.
-
- Args:
- other: Another operation in the same parent block.
-
- Returns:
- True if this operation is before `other` in the operation list of the parent block.)")
- .def(
- "clone",
- [](PyOperationBase &self,
- const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
- return self.getOperation().clone(ip);
- },
- nb::arg("ip") = nb::none(),
- R"(
- Creates a deep copy of the operation.
-
- Args:
- ip: Optional insertion point where the cloned operation should be inserted.
- If None, the current insertion point is used. If False, the operation
- remains detached.
-
- Returns:
- A new Operation that is a clone of this operation.)")
- .def(
- "detach_from_parent",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- if (!operation.isAttached())
- throw nb::value_error("Detached operation has no parent.");
-
- operation.detachFromParent();
- return operation.createOpView();
- },
- "Detaches the operation from its parent block.")
- .def_prop_ro(
- "attached",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- return operation.isAttached();
- },
- "Reports if the operation is attached to its parent block.")
- .def(
- "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
- R"(
- Erases the operation and frees its memory.
-
- Note:
- After erasing, any Python references to the operation become invalid.)")
- .def("walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
- // clang-format on
- R"(
- Walks the operation tree with a callback function.
-
- Args:
- callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
-
- nb::class_<PyOperation, PyOperationBase>(m, "Operation")
- .def_static(
- "create",
- [](std::string_view name,
- std::optional<std::vector<PyType *>> results,
- std::optional<std::vector<PyValue *>> operands,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors, int regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp,
- bool inferType) -> nb::typed<nb::object, PyOperation> {
- // Unpack/validate operands.
- llvm::SmallVector<MlirValue, 4> mlirOperands;
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue *operand : *operands) {
- if (!operand)
- throw nb::value_error("operand value cannot be None");
- mlirOperands.push_back(operand->get());
- }
- }
-
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOperation::create(name, results, mlirOperands, attributes,
- successors, regions, pyLoc, maybeIp,
- inferType);
- },
- nb::arg("name"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- nb::arg("infer_type") = false,
- R"(
- Creates a new operation.
-
- Args:
- name: Operation name (e.g. `dialect.operation`).
- results: Optional sequence of Type representing op result types.
- operands: Optional operands of the operation.
- attributes: Optional Dict of {str: Attribute}.
- successors: Optional List of Block for the operation's successors.
- regions: Number of regions to create (default = 0).
- location: Optional Location object (defaults to resolve from context manager).
- ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
- infer_type: Whether to infer result types (default = False).
- Returns:
- A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
- .def_static(
- "parse",
- [](const std::string &sourceStr, const std::string &sourceName,
- DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyOpView> {
- return PyOperation::parse(context->getRef(), sourceStr, sourceName)
- ->createOpView();
- },
- nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
- nb::arg("context") = nb::none(),
- "Parses an operation. Supports both text assembly format and binary "
- "bytecode format.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
- "Gets a capsule wrapping the MlirOperation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyOperation::createFromCapsule,
- "Creates an Operation from a capsule wrapping MlirOperation.")
- .def_prop_ro(
- "operation",
- [](nb::object self) -> nb::typed<nb::object, PyOperation> {
- return self;
- },
- "Returns self (the operation).")
- .def_prop_ro(
- "opview",
- [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
- return self.createOpView();
- },
- R"(
- Returns an OpView of this operation.
-
- Note:
- If the operation has a registered and loaded dialect then this OpView will
- be concrete wrapper class.)")
- .def_prop_ro("block", &PyOperation::getBlock,
- "Returns the block containing this operation.")
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def(
- "replace_uses_of_with",
- [](PyOperation &self, PyValue &of, PyValue &with) {
- mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
- },
- "of"_a, "with_"_a,
- "Replaces uses of the 'of' value with the 'with' value inside the "
- "operation.")
- .def("_set_invalid", &PyOperation::setInvalid,
- "Invalidate the operation.");
-
- auto opViewClass =
- nb::class_<PyOpView, PyOperationBase>(m, "OpView")
- .def(nb::init<nb::typed<nb::object, PyOperation>>(),
- nb::arg("operation"))
- .def(
- "__init__",
- [](PyOpView *self, std::string_view name,
- std::tuple<int, bool> opRegionSpec,
- nb::object operandSegmentSpecObj,
- nb::object resultSegmentSpecObj,
- std::optional<nb::list> resultTypeList, nb::list operandList,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp) {
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- new (self) PyOpView(PyOpView::buildGeneric(
- name, opRegionSpec, operandSegmentSpecObj,
- resultSegmentSpecObj, resultTypeList, operandList,
- attributes, successors, regions, pyLoc, maybeIp));
- },
- nb::arg("name"), nb::arg("opRegionSpec"),
- nb::arg("operandSegmentSpecObj") = nb::none(),
- nb::arg("resultSegmentSpecObj") = nb::none(),
- nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
- nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(),
- nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
- nb::arg("ip") = nb::none())
- .def_prop_ro(
- "operation",
- [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
- return self.getOperationObject();
- })
- .def_prop_ro("opview",
- [](nb::object self) -> nb::typed<nb::object, PyOpView> {
- return self;
- })
- .def(
- "__str__",
- [](PyOpView &self) { return nb::str(self.getOperationObject()); })
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def(
- "_set_invalid",
- [](PyOpView &self) { self.getOperation().setInvalid(); },
- "Invalidate the operation.");
- opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
- opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
- opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
- // It is faster to pass the operation_name, ods_regions, and
- // ods_operand_segments/ods_result_segments as arguments to the constructor,
- // rather than to access them as attributes.
- opViewClass.attr("build_generic") = classmethod(
- [](nb::handle cls, std::optional<nb::list> resultTypeList,
- nb::list operandList, std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, std::optional<PyLocation> location,
- const nb::object &maybeIp) {
- std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- std::tuple<int, bool> opRegionSpec =
- nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
- nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
- nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
- resultSegmentSpec, resultTypeList,
- operandList, attributes, successors,
- regions, pyLoc, maybeIp);
- },
- nb::arg("cls"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- "Builds a specific, generated OpView based on class level attributes.");
- opViewClass.attr("parse") = classmethod(
- [](const nb::object &cls, const std::string &sourceStr,
- const std::string &sourceName,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
- PyOperationRef parsed =
- PyOperation::parse(context->getRef(), sourceStr, sourceName);
-
- // Check if the expected operation was parsed, and cast to to the
- // appropriate `OpView` subclass if successful.
- // NOTE: This accesses attributes that have been automatically added to
- // `OpView` subclasses, and is not intended to be used on `OpView`
- // directly.
- std::string clsOpName =
- nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- MlirStringRef identifier =
- mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
- std::string_view parsedOpName(identifier.data, identifier.length);
- if (clsOpName != parsedOpName)
- throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
- parsedOpName + "'");
- return PyOpView::constructDerived(cls, parsed.getObject());
- },
- nb::arg("cls"), nb::arg("source"), nb::kw_only(),
- nb::arg("source_name") = "", nb::arg("context") = nb::none(),
- "Parses a specific, generated OpView based on class level attributes.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyRegion.
- //----------------------------------------------------------------------------
- nb::class_<PyRegion>(m, "Region")
- .def_prop_ro(
- "blocks",
- [](PyRegion &self) {
- return PyBlockList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of blocks.")
- .def_prop_ro(
- "owner",
- [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation owning this region.")
- .def(
- "__iter__",
- [](PyRegion &self) {
- self.checkValid();
- MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
- return PyBlockIterator(self.getParentOperation(), firstBlock);
- },
- "Iterates over blocks in the region.")
- .def(
- "__eq__",
- [](PyRegion &self, PyRegion &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two regions for pointer equality.")
- .def(
- "__eq__", [](PyRegion &self, nb::object &other) { return false; },
- "Compares region with non-region object (always returns False).");
-
- //----------------------------------------------------------------------------
- // Mapping of PyBlock.
- //----------------------------------------------------------------------------
- nb::class_<PyBlock>(m, "Block")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
- "Gets a capsule wrapping the MlirBlock.")
- .def_prop_ro(
- "owner",
- [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the owning operation of this block.")
- .def_prop_ro(
- "region",
- [](PyBlock &self) {
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- return PyRegion(self.getParentOperation(), region);
- },
- "Returns the owning region of this block.")
- .def_prop_ro(
- "arguments",
- [](PyBlock &self) {
- return PyBlockArgumentList(self.getParentOperation(), self.get());
- },
- "Returns a list of block arguments.")
- .def(
- "add_argument",
- [](PyBlock &self, const PyType &type, const PyLocation &loc) {
- return PyBlockArgument(self.getParentOperation(),
- mlirBlockAddArgument(self.get(), type, loc));
- },
- "type"_a, "loc"_a,
- R"(
- Appends an argument of the specified type to the block.
-
- Args:
- type: The type of the argument to add.
- loc: The source location for the argument.
-
- Returns:
- The newly added block argument.)")
- .def(
- "erase_argument",
- [](PyBlock &self, unsigned index) {
- return mlirBlockEraseArgument(self.get(), index);
- },
- nb::arg("index"),
- R"(
- Erases the argument at the specified index.
-
- Args:
- index: The index of the argument to erase.)")
- .def_prop_ro(
- "operations",
- [](PyBlock &self) {
- return PyOperationList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of operations.")
- .def_static(
- "create_at_start",
- [](PyRegion &parent, const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- parent.checkValid();
- MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
- mlirRegionInsertOwnedBlock(parent, 0, block);
- return PyBlock(parent.getParentOperation(), block);
- },
- nb::arg("parent"), nb::arg("arg_types") = nb::list(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block at the beginning of the given "
- "region (with given argument types and locations).")
- .def(
- "append_to",
- [](PyBlock &self, PyRegion ®ion) {
- MlirBlock b = self.get();
- if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
- mlirBlockDetach(b);
- mlirRegionAppendOwnedBlock(region.get(), b);
- },
- nb::arg("region"),
- R"(
- Appends this block to a region.
-
- Transfers ownership if the block is currently owned by another region.
-
- Args:
- region: The region to append the block to.)")
- .def(
- "create_before",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block before this block "
- "(with given argument types and locations).")
- .def(
- "create_after",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block after this block "
- "(with given argument types and locations).")
- .def(
- "__iter__",
- [](PyBlock &self) {
- self.checkValid();
- MlirOperation firstOperation =
- mlirBlockGetFirstOperation(self.get());
- return PyOperationIterator(self.getParentOperation(),
- firstOperation);
- },
- "Iterates over operations in the block.")
- .def(
- "__eq__",
- [](PyBlock &self, PyBlock &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two blocks for pointer equality.")
- .def(
- "__eq__", [](PyBlock &self, nb::object &other) { return false; },
- "Compares block with non-block object (always returns False).")
- .def(
- "__hash__",
- [](PyBlock &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the block.")
- .def(
- "__str__",
- [](PyBlock &self) {
- self.checkValid();
- PyPrintAccumulator printAccum;
- mlirBlockPrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the block.")
- .def(
- "append",
- [](PyBlock &self, PyOperationBase &operation) {
- if (operation.getOperation().isAttached())
- operation.getOperation().detachFromParent();
-
- MlirOperation mlirOperation = operation.getOperation().get();
- mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
- operation.getOperation().setAttached(
- self.getParentOperation().getObject());
- },
- nb::arg("operation"),
- R"(
- Appends an operation to this block.
-
- If the operation is currently in another block, it will be moved.
-
- Args:
- operation: The operation to append to the block.)")
- .def_prop_ro(
- "successors",
- [](PyBlock &self) {
- return PyBlockSuccessors(self, self.getParentOperation());
- },
- "Returns the list of Block successors.")
- .def_prop_ro(
- "predecessors",
- [](PyBlock &self) {
- return PyBlockPredecessors(self, self.getParentOperation());
- },
- "Returns the list of Block predecessors.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyInsertionPoint.
- //----------------------------------------------------------------------------
-
- nb::class_<PyInsertionPoint>(m, "InsertionPoint")
- .def(nb::init<PyBlock &>(), nb::arg("block"),
- "Inserts after the last operation but still inside the block.")
- .def("__enter__", &PyInsertionPoint::contextEnter,
- "Enters the insertion point as a context manager.")
- .def("__exit__", &PyInsertionPoint::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the insertion point context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) {
- auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
- if (!ip)
- throw nb::value_error("No current InsertionPoint");
- return ip;
- },
- nb::sig("def current(/) -> InsertionPoint"),
- "Gets the InsertionPoint bound to the current thread or raises "
- "ValueError if none has been set.")
- .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
- "Inserts before a referenced operation.")
- .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
- nb::arg("block"),
- R"(
- Creates an insertion point at the beginning of a block.
-
- Args:
- block: The block at whose beginning operations should be inserted.
-
- Returns:
- An InsertionPoint at the block's beginning.)")
- .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
- nb::arg("block"),
- R"(
- Creates an insertion point before a block's terminator.
-
- Args:
- block: The block whose terminator to insert before.
-
- Returns:
- An InsertionPoint before the terminator.
-
- Raises:
- ValueError: If the block has no terminator.)")
- .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
- R"(
- Creates an insertion point immediately after an operation.
-
- Args:
- operation: The operation after which to insert.
-
- Returns:
- An InsertionPoint after the operation.)")
- .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
- R"(
- Inserts an operation at this insertion point.
-
- Args:
- operation: The operation to insert.)")
- .def_prop_ro(
- "block", [](PyInsertionPoint &self) { return self.getBlock(); },
- "Returns the block that this `InsertionPoint` points to.")
- .def_prop_ro(
- "ref_operation",
- [](PyInsertionPoint &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto refOperation = self.getRefOperation();
- if (refOperation)
- return refOperation->getObject();
- return {};
- },
- "The reference operation before which new operations are "
- "inserted, or None if the insertion point is at the end of "
- "the block.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyAttribute.
- //----------------------------------------------------------------------------
- nb::class_<PyAttribute>(m, "Attribute")
- // Delegate to the PyAttribute copy constructor, which will also lifetime
- // extend the backing context which owns the MlirAttribute.
- .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
- "Casts the passed attribute to the generic `Attribute`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
- "Gets a capsule wrapping the MlirAttribute.")
- .def_static(
- MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
- "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
- .def_static(
- "parse",
- [](const std::string &attrSpec, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyAttribute> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute attr = mlirAttributeParseGet(
- context->get(), toMlirStringRef(attrSpec));
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Unable to parse attribute", errors.take());
- return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- "Parses an attribute from an assembly form. Raises an `MLIRError` on "
- "failure.")
- .def_prop_ro(
- "context",
- [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Attribute`.")
- .def_prop_ro(
- "type",
- [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirAttributeGetType(self))
- .maybeDownCast();
- },
- "Returns the type of the `Attribute`.")
- .def(
- "get_named",
- [](PyAttribute &self, std::string name) {
- return PyNamedAttribute(self, std::move(name));
- },
- nb::keep_alive<0, 1>(),
- R"(
- Binds a name to the attribute, creating a `NamedAttribute`.
-
- Args:
- name: The name to bind to the `Attribute`.
-
- Returns:
- A `NamedAttribute` with the given name and this attribute.)")
- .def(
- "__eq__",
- [](PyAttribute &self, PyAttribute &other) { return self == other; },
- "Compares two attributes for equality.")
- .def(
- "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
- "Compares attribute with non-attribute object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyAttribute &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the attribute.")
- .def(
- "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
- kDumpDocstring)
- .def(
- "__str__",
- [](PyAttribute &self) {
- PyPrintAccumulator printAccum;
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the Attribute.")
- .def(
- "__repr__",
- [](PyAttribute &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, attribute values are generally considered useful and
- // are printed. This may need to be re-evaluated if debug dumps end
- // up being excessive.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Attribute(");
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the attribute.")
- .def_prop_ro(
- "typeid",
- [](PyAttribute &self) {
- MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
- assert(!mlirTypeIDIsNull(mlirTypeID) &&
- "mlirTypeID was expected to be non-null.");
- return PyTypeID(mlirTypeID);
- },
- "Returns the `TypeID` of the attribute.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
- return self.maybeDownCast();
- },
- "Downcasts the attribute to a more specific attribute if possible.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyNamedAttribute
- //----------------------------------------------------------------------------
- nb::class_<PyNamedAttribute>(m, "NamedAttribute")
- .def(
- "__repr__",
- [](PyNamedAttribute &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("NamedAttribute(");
- printAccum.parts.append(
- nb::str(mlirIdentifierStr(self.namedAttr.name).data,
- mlirIdentifierStr(self.namedAttr.name).length));
- printAccum.parts.append("=");
- mlirAttributePrint(self.namedAttr.attribute,
- printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the named attribute.")
- .def_prop_ro(
- "name",
- [](PyNamedAttribute &self) {
- return mlirIdentifierStr(self.namedAttr.name);
- },
- "The name of the `NamedAttribute` binding.")
- .def_prop_ro(
- "attr",
- [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
- nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
- "The underlying generic attribute of the `NamedAttribute` binding.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyType.
- //----------------------------------------------------------------------------
- nb::class_<PyType>(m, "Type")
- // Delegate to the PyType copy constructor, which will also lifetime
- // extend the backing context which owns the MlirType.
- .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
- "Casts the passed type to the generic `Type`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
- "Gets a capsule wrapping the `MlirType`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
- "Creates a Type from a capsule wrapping `MlirType`.")
- .def_static(
- "parse",
- [](std::string typeSpec,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type =
- mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
- if (mlirTypeIsNull(type))
- throw MLIRError("Unable to parse type", errors.take());
- return PyType(context.get()->getRef(), type).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- R"(
- Parses the assembly form of a type.
-
- Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
-
- See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
- .def_prop_ro(
- "context",
- [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Type`.")
- .def(
- "__eq__", [](PyType &self, PyType &other) { return self == other; },
- "Compares two types for equality.")
- .def(
- "__eq__", [](PyType &self, nb::object &other) { return false; },
- nb::arg("other").none(),
- "Compares type with non-type object (always returns False).")
- .def(
- "__hash__",
- [](PyType &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the `Type`.")
- .def(
- "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
- .def(
- "__str__",
- [](PyType &self) {
- PyPrintAccumulator printAccum;
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the `Type`.")
- .def(
- "__repr__",
- [](PyType &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, types are an exception as they typically have compact
- // assembly forms and printing them is useful.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Type(");
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the `Type`.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyType &self) -> nb::typed<nb::object, PyType> {
- return self.maybeDownCast();
- },
- "Downcasts the Type to a more specific `Type` if possible.")
- .def_prop_ro(
- "typeid",
- [](PyType &self) {
- MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
- if (!mlirTypeIDIsNull(mlirTypeID))
- return PyTypeID(mlirTypeID);
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
- throw nb::value_error(
- (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
- },
- "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
- "`Type` has no "
- "`TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyTypeID.
- //----------------------------------------------------------------------------
- nb::class_<PyTypeID>(m, "TypeID")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
- "Gets a capsule wrapping the `MlirTypeID`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
- "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
- // Note, this tests whether the underlying TypeIDs are the same,
- // not whether the wrapper MlirTypeIDs are the same, nor whether
- // the Python objects are the same (i.e., PyTypeID is a value type).
- .def(
- "__eq__",
- [](PyTypeID &self, PyTypeID &other) { return self == other; },
- "Compares two `TypeID`s for equality.")
- .def(
- "__eq__",
- [](PyTypeID &self, const nb::object &other) { return false; },
- "Compares TypeID with non-TypeID object (always returns False).")
- // Note, this gives the hash value of the underlying TypeID, not the
- // hash value of the Python object, nor the hash value of the
- // MlirTypeID wrapper.
- .def(
- "__hash__",
- [](PyTypeID &self) {
- return static_cast<size_t>(mlirTypeIDHashValue(self));
- },
- "Returns the hash value of the `TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of Value.
- //----------------------------------------------------------------------------
- m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
-
- nb::class_<PyValue>(m, "Value", nb::is_generic(),
- nb::sig("class Value(Generic[_T])"))
- .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
- "Creates a Value reference from another `Value`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
- "Gets a capsule wrapping the `MlirValue`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
- "Creates a `Value` from a capsule wrapping `MlirValue`.")
- .def_prop_ro(
- "context",
- [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getParentOperation()->getContext().getObject();
- },
- "Context in which the value lives.")
- .def(
- "dump", [](PyValue &self) { mlirValueDump(self.get()); },
- kDumpDocstring)
- .def_prop_ro(
- "owner",
- [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
- MlirValue v = self.get();
- if (mlirValueIsAOpResult(v)) {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match "
- "that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- }
-
- if (mlirValueIsABlockArgument(v)) {
- MlirBlock block = mlirBlockArgumentGetOwner(self.get());
- return nb::cast(PyBlock(self.getParentOperation(), block));
- }
-
- assert(false && "Value must be a block argument or an op result");
- return nb::none();
- },
- "Returns the owner of the value (`Operation` for results, `Block` "
- "for "
- "arguments).")
- .def_prop_ro(
- "uses",
- [](PyValue &self) {
- return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
- },
- "Returns an iterator over uses of this value.")
- .def(
- "__eq__",
- [](PyValue &self, PyValue &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two values for pointer equality.")
- .def(
- "__eq__", [](PyValue &self, nb::object other) { return false; },
- "Compares value with non-value object (always returns False).")
- .def(
- "__hash__",
- [](PyValue &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the value.")
- .def(
- "__str__",
- [](PyValue &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Value(");
- mlirValuePrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- R"(
- Returns the string form of the value.
-
- If the value is a block argument, this is the assembly form of its type and the
- position in the argument list. If the value is an operation result, this is
- equivalent to printing the operation that produced it.
- )")
- .def(
- "get_name",
- [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
- PyPrintAccumulator printAccum;
- MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- if (useNameLocAsPrefix)
- mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
- MlirAsmState valueState =
- mlirAsmStateCreateForValue(self.get(), flags);
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- mlirOpPrintingFlagsDestroy(flags);
- mlirAsmStateDestroy(valueState);
- return printAccum.join();
- },
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- R"(
- Returns the string form of value as an operand.
-
- Args:
- use_local_scope: Whether to use local scope for naming.
- use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
-
- Returns:
- The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
- .def(
- "get_name",
- [](PyValue &self, PyAsmState &state) {
- PyPrintAccumulator printAccum;
- MlirAsmState valueState = state.get();
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- nb::arg("state"),
- "Returns the string form of value as an operand (i.e., the ValueID).")
- .def_prop_ro(
- "type",
- [](PyValue &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()))
- .maybeDownCast();
- },
- "Returns the type of the value.")
- .def(
- "set_type",
- [](PyValue &self, const PyType &type) {
- mlirValueSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of the value.",
- nb::sig("def set_type(self, type: _T)"))
- .def(
- "replace_all_uses_with",
- [](PyValue &self, PyValue &with) {
- mlirValueReplaceAllUsesOfWith(self.get(), with.get());
- },
- "Replace all uses of value with the new value, updating anything in "
- "the IR that uses `self` to use the other value instead.")
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, const nb::list &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (nb::handle exception : exceptions) {
- exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
- }
-
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with,
- std::vector<PyOperation> &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (PyOperation &exception : exceptions)
- exceptionOps.push_back(exception);
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- },
- "Downcasts the `Value` to a more specific kind if possible.")
- .def_prop_ro(
- "location",
- [](MlirValue self) {
- return PyLocation(
- PyMlirContext::forContext(mlirValueGetContext(self)),
- mlirValueGetLocation(self));
- },
- "Returns the source location of the value.");
-
- PyBlockArgument::bind(m);
- PyOpResult::bind(m);
- PyOpOperand::bind(m);
-
- nb::class_<PyAsmState>(m, "AsmState")
- .def(nb::init<PyValue &, bool>(), nb::arg("value"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an `AsmState` for consistent SSA value naming.
-
- Args:
- value: The value to create state for.
- use_local_scope: Whether to use local scope for naming.)")
- .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an AsmState for consistent SSA value naming.
-
- Args:
- op: The operation to create state for.
- use_local_scope: Whether to use local scope for naming.)");
-
- //----------------------------------------------------------------------------
- // Mapping of SymbolTable.
- //----------------------------------------------------------------------------
- nb::class_<PySymbolTable>(m, "SymbolTable")
- .def(nb::init<PyOperationBase &>(),
- R"(
- Creates a symbol table for an operation.
-
- Args:
- operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
-
- Raises:
- TypeError: If the operation is not a symbol table.)")
- .def(
- "__getitem__",
- [](PySymbolTable &self,
- const std::string &name) -> nb::typed<nb::object, PyOpView> {
- return self.dunderGetItem(name);
- },
- R"(
- Looks up a symbol by name in the symbol table.
-
- Args:
- name: The name of the symbol to look up.
-
- Returns:
- The operation defining the symbol.
-
- Raises:
- KeyError: If the symbol is not found.)")
- .def("insert", &PySymbolTable::insert, nb::arg("operation"),
- R"(
- Inserts a symbol operation into the symbol table.
-
- Args:
- operation: An operation with a symbol name to insert.
-
- Returns:
- The symbol name attribute of the inserted operation.
-
- Raises:
- ValueError: If the operation does not have a symbol name.)")
- .def("erase", &PySymbolTable::erase, nb::arg("operation"),
- R"(
- Erases a symbol operation from the symbol table.
-
- Args:
- operation: The symbol operation to erase.
-
- Note:
- The operation is also erased from the IR and invalidated.)")
- .def("__delitem__", &PySymbolTable::dunderDel,
- "Deletes a symbol by name from the symbol table.")
- .def(
- "__contains__",
- [](PySymbolTable &table, const std::string &name) {
- return !mlirOperationIsNull(mlirSymbolTableLookup(
- table, mlirStringRefCreate(name.data(), name.length())));
- },
- "Checks if a symbol with the given name exists in the table.")
- // Static helpers.
- .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
- nb::arg("symbol"), nb::arg("name"),
- "Sets the symbol name for a symbol operation.")
- .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
- nb::arg("symbol"),
- "Gets the symbol name from a symbol operation.")
- .def_static("get_visibility", &PySymbolTable::getVisibility,
- nb::arg("symbol"),
- "Gets the visibility attribute of a symbol operation.")
- .def_static("set_visibility", &PySymbolTable::setVisibility,
- nb::arg("symbol"), nb::arg("visibility"),
- "Sets the visibility attribute of a symbol operation.")
- .def_static("replace_all_symbol_uses",
- &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
- nb::arg("new_symbol"), nb::arg("from_op"),
- "Replaces all uses of a symbol with a new symbol name within "
- "the given operation.")
- .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
- nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
- nb::arg("callback"),
- "Walks symbol tables starting from an operation with a "
- "callback function.");
-
- // Container bindings.
- PyBlockArgumentList::bind(m);
- PyBlockIterator::bind(m);
- PyBlockList::bind(m);
- PyBlockSuccessors::bind(m);
- PyBlockPredecessors::bind(m);
- PyOperationIterator::bind(m);
- PyOperationList::bind(m);
- PyOpAttributeMap::bind(m);
- PyOpOperandIterator::bind(m);
- PyOpOperandList::bind(m);
- PyOpResultList::bind(m);
- PyOpSuccessors::bind(m);
- PyRegionIterator::bind(m);
- PyRegionList::bind(m);
-
- // Debug bindings.
- PyGlobalDebugFlag::bind(m);
-
- // Attribute builder getter.
- PyAttrBuilderMap::bind(m);
-
+void registerMLIRErrorInIRCore() {
nb::register_exception_translator([](const std::exception_ptr &p,
void *payload) {
- // We can't define exceptions with custom fields through pybind, so instead
- // the exception class is defined in python and imported here.
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
try {
if (p)
std::rethrow_exception(p);
@@ -4971,3 +1687,4 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
});
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 31d4798ffb906..f1e494c375523 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,7 +12,7 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 34c5b8dd86a66..294ab91a059e2 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -7,14 +7,13 @@
//===----------------------------------------------------------------------===//
// clang-format off
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRTypes.h"
// clang-format on
#include <optional>
-#include "IRModule.h"
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
@@ -1144,7 +1143,8 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
} // namespace
-void mlir::python::populateIRTypes(nb::module_ &m) {
+namespace mlir::python {
+void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
@@ -1175,4 +1175,18 @@ void mlir::python::populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+});
+}
}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index ba767ad6692cf..686c55ee1e6a8 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,18 +6,2275 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
using namespace mlir::python;
+// Docstring text used by the Python bindings defined below.
+static constexpr char kModuleParseDocstring[] =
+    R"(Parses a module's assembly format from a string.
+
+Returns a new MlirModule or raises an MLIRError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+
+static constexpr char kDumpDocstring[] =
+    "Dumps a debug representation of the object to stderr.";
+
+static constexpr char kValueReplaceAllUsesExceptDocstring[] =
+    R"(Replace all uses of this value with the `with` value, except for those
+in `exceptions`. `exceptions` can be either a single operation or a list of
+operations.
+)";
+
+namespace {
+// Backports of CPython C-API helpers that are missing from older Python
+// versions, adapted from
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
+
+#ifndef _Py_CAST
+#define _Py_CAST(type, expr) ((type)(expr))
+#endif
+
+// Static inline functions should use _Py_NULL rather than using directly NULL
+// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
+// _Py_NULL is defined as nullptr.
+#ifndef _Py_NULL
+#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) ||               \
+    (defined(__cplusplus) && __cplusplus >= 201103)
+#define _Py_NULL nullptr
+#else
+#define _Py_NULL NULL
+#endif
+#endif
+
+// Python 3.10.0a3
+#if PY_VERSION_HEX < 0x030A00A3
+
+// bpo-42262 added Py_XNewRef()
+#if !defined(Py_XNewRef)
+[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
+  Py_XINCREF(obj);
+  return obj;
+}
+#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
+#endif
+
+// bpo-42262 added Py_NewRef()
+#if !defined(Py_NewRef)
+[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
+  Py_INCREF(obj);
+  return obj;
+}
+#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
+#endif
+
+#endif // Python 3.10.0a3
+
+// Python 3.9.0b1
+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
+
+// bpo-40429 added PyThreadState_GetFrame(). Returns a new reference.
+PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
+  assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
+}
+
+// bpo-40421 added PyFrame_GetBack(). Returns a new reference.
+PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
+}
+
+// bpo-40421 added PyFrame_GetCode(). Returns a new reference.
+PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
+  return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
+}
+
+#endif // Python 3.9.0b1
+
+/// Builds an MLIR location describing the current Python call stack.
+///
+/// Walks the frames of the current thread, innermost first. It keeps only
+/// frames whose filename passes `isUserTracebackFilename`. It stops after
+/// `locTracebackFramesLimit()` kept frames, further capped at the size of the
+/// thread-local scratch buffer. Each kept frame becomes a `NameLoc`
+/// (function name) wrapping a file/line (3.11+: file/line/column range)
+/// location. The result is:
+///   - `UnknownLoc` if no frame was kept,
+///   - the single `NameLoc` if exactly one frame was kept,
+///   - otherwise a `CallSiteLoc` whose callee is the innermost kept frame and
+///     whose caller folds the remaining kept frames into nested
+///     `CallSiteLoc`s.
+///
+/// Throws `nb::python_error` if `PyCode_Addr2Location` fails (3.11+).
+MlirLocation tracebackToLocation(MlirContext ctx) {
+  size_t framesLimit =
+      PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
+  // Use a thread_local here to avoid requiring a large amount of space.
+  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
+      frames;
+  // Never write past the end of the scratch buffer, even if the configured
+  // limit is larger than it.
+  if (framesLimit > frames.size())
+    framesLimit = frames.size();
+  size_t count = 0;
+
+  nb::gil_scoped_acquire acquire;
+  PyThreadState *tstate = PyThreadState_GET();
+  PyFrameObject *next;
+  PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
+  // In the increment expression:
+  // 1. get the next prev frame;
+  // 2. decrement the ref count on the current frame (in order that it can get
+  // gc'd, along with any objects in its closure and etc);
+  // 3. set current = next.
+  for (; pyFrame != nullptr && count < framesLimit;
+       next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
+    // PyFrame_GetCode returns a new (strong) reference. Hand it to an owning
+    // nb::object so it is released on every path out of this iteration
+    // (`continue`, normal end of the body, or an exception).
+    PyCodeObject *code = PyFrame_GetCode(pyFrame);
+    nb::object codeRef = nb::steal(reinterpret_cast<PyObject *>(code));
+    auto fileNameStr =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
+    llvm::StringRef fileName(fileNameStr);
+    if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
+      continue;
+
+    // co_qualname and PyCode_Addr2Location added in py3.11
+#if PY_VERSION_HEX < 0x030B00F0
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
+    llvm::StringRef funcName(name);
+    int startLine = PyFrame_GetLineNumber(pyFrame);
+    MlirLocation loc =
+        mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
+#else
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
+    llvm::StringRef funcName(name);
+    int startLine, startCol, endLine, endCol;
+    int lasti = PyFrame_GetLasti(pyFrame);
+    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
+                              &endCol)) {
+      // The loop increment (which releases the current frame) will not run
+      // when we throw, so release the frame here to avoid leaking it.
+      Py_XDECREF(pyFrame);
+      throw nb::python_error();
+    }
+    MlirLocation loc = mlirLocationFileLineColRangeGet(
+        ctx, wrap(fileName), startLine, startCol, endLine, endCol);
+#endif
+
+    frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
+    ++count;
+  }
+  // When the loop breaks (after the last iter), current frame (if non-null)
+  // is leaked without this.
+  Py_XDECREF(pyFrame);
+
+  if (count == 0)
+    return mlirLocationUnknownGet(ctx);
+
+  MlirLocation callee = frames[0];
+  assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
+  if (count == 1)
+    return callee;
+
+  // Fold frames[1 .. count-1] right-to-left into nested CallSiteLocs:
+  // callsite(frames[1], callsite(frames[2], ... frames[count-1])).
+  MlirLocation caller = frames[count - 1];
+  assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
+  for (size_t i = count - 2; i >= 1; --i)
+    caller = mlirLocationCallSiteGet(frames[i], caller);
+
+  return mlirLocationCallSiteGet(callee, caller);
+}
+
+/// Resolves the location to attach to a newly built IR entity.
+///
+/// Returns `location` when one is provided. Otherwise, if traceback
+/// locations are disabled, falls back to `DefaultingPyLocation::resolve()`.
+/// If they are enabled, builds a location from the current Python call stack
+/// in the default context (see `tracebackToLocation`).
+PyLocation
+maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
+  if (location.has_value())
+    return location.value();
+  if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
+    return DefaultingPyLocation::resolve();
+
+  PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
+  MlirLocation mlirLoc = tracebackToLocation(ctx.get());
+  PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
+  return {ref, mlirLoc};
+}
+} // namespace
+
+//------------------------------------------------------------------------------
+// Populates the core exports of the 'ir' submodule.
+//------------------------------------------------------------------------------
+
+static void populateIRCore(nb::module_ &m) {
+ // disable leak warnings which tend to be false positives.
+ nb::set_leak_warnings(false);
+ //----------------------------------------------------------------------------
+ // Enums.
+ //----------------------------------------------------------------------------
+ nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
+ .value("ERROR", MlirDiagnosticError)
+ .value("WARNING", MlirDiagnosticWarning)
+ .value("NOTE", MlirDiagnosticNote)
+ .value("REMARK", MlirDiagnosticRemark);
+
+ nb::enum_<MlirWalkOrder>(m, "WalkOrder")
+ .value("PRE_ORDER", MlirWalkPreOrder)
+ .value("POST_ORDER", MlirWalkPostOrder);
+
+ nb::enum_<MlirWalkResult>(m, "WalkResult")
+ .value("ADVANCE", MlirWalkResultAdvance)
+ .value("INTERRUPT", MlirWalkResultInterrupt)
+ .value("SKIP", MlirWalkResultSkip);
+
+ //----------------------------------------------------------------------------
+ // Mapping of Diagnostics.
+ //----------------------------------------------------------------------------
+ nb::class_<PyDiagnostic>(m, "Diagnostic")
+ .def_prop_ro("severity", &PyDiagnostic::getSeverity,
+ "Returns the severity of the diagnostic.")
+ .def_prop_ro("location", &PyDiagnostic::getLocation,
+ "Returns the location associated with the diagnostic.")
+ .def_prop_ro("message", &PyDiagnostic::getMessage,
+ "Returns the message text of the diagnostic.")
+ .def_prop_ro("notes", &PyDiagnostic::getNotes,
+ "Returns a tuple of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic &self) -> nb::str {
+ if (!self.isValid())
+ return nb::str("<Invalid Diagnostic>");
+ return self.getMessage();
+ },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
+ .def(
+ "__init__",
+ [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
+ new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
+ },
+ "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
+ .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
+ "The severity level of the diagnostic.")
+ .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
+ "The location associated with the diagnostic.")
+ .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
+ "The message text of the diagnostic.")
+ .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
+ "List of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
+ .def("detach", &PyDiagnosticHandler::detach,
+ "Detaches the diagnostic handler from the context.")
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
+ "Returns True if the handler is attached to a context.")
+ .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
+ "Returns True if an error was encountered during diagnostic "
+ "handling.")
+ .def("__enter__", &PyDiagnosticHandler::contextEnter,
+ "Enters the diagnostic handler as a context manager.")
+ .def("__exit__", &PyDiagnosticHandler::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none(),
+ "Exits the diagnostic handler context manager.");
+
+ // Expose DefaultThreadPool to python
+ nb::class_<PyThreadPool>(m, "ThreadPool")
+ .def(
+ "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
+ "Creates a new thread pool with default concurrency.")
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
+ "Returns the maximum number of threads in the pool.")
+ .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
+ "Returns the raw pointer to the LLVM thread pool as a string.");
+
+ nb::class_<PyMlirContext>(m, "Context")
+ .def(
+ "__init__",
+ [](PyMlirContext &self) {
+ MlirContext context = mlirContextCreateWithThreading(false);
+ new (&self) PyMlirContext(context);
+ },
+ R"(
+ Creates a new MLIR context.
+
+ The context is the top-level container for all MLIR objects. It owns the storage
+ for types, attributes, locations, and other core IR objects. A context can be
+ configured to allow or disallow unregistered dialects and can have dialects
+ loaded on-demand.)")
+ .def_static("_get_live_count", &PyMlirContext::getLiveCount,
+ "Gets the number of live Context objects.")
+ .def(
+ "_get_context_again",
+ [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
+ },
+ "Gets another reference to the same context.")
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
+ "Gets the number of live modules owned by this context.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
+ "Gets a capsule wrapping the MlirContext.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyMlirContext::createFromCapsule,
+ "Creates a Context from a capsule wrapping MlirContext.")
+ .def("__enter__", &PyMlirContext::contextEnter,
+ "Enters the context as a context manager.")
+ .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the context manager.")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/)
+ -> std::optional<nb::typed<nb::object, PyMlirContext>> {
+ auto *context = PyThreadContextEntry::getDefaultContext();
+ if (!context)
+ return {};
+ return nb::cast(context);
+ },
+ nb::sig("def current(/) -> Context | None"),
+ "Gets the Context bound to the current thread or returns None if no "
+ "context is set.")
+ .def_prop_ro(
+ "dialects",
+ [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Gets a container for accessing dialects by name.")
+ .def_prop_ro(
+ "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Alias for `dialects`.")
+ .def(
+ "get_dialect_descriptor",
+ [=](PyMlirContext &self, std::string &name) {
+ MlirDialect dialect = mlirContextGetOrLoadDialect(
+ self.get(), {name.data(), name.size()});
+ if (mlirDialectIsNull(dialect)) {
+ throw nb::value_error(
+ (Twine("Dialect '") + name + "' not found").str().c_str());
+ }
+ return PyDialectDescriptor(self.getRef(), dialect);
+ },
+ nb::arg("dialect_name"),
+ "Gets or loads a dialect by name, returning its descriptor object.")
+ .def_prop_rw(
+ "allow_unregistered_dialects",
+ [](PyMlirContext &self) -> bool {
+ return mlirContextGetAllowUnregisteredDialects(self.get());
+ },
+ [](PyMlirContext &self, bool value) {
+ mlirContextSetAllowUnregisteredDialects(self.get(), value);
+ },
+ "Controls whether unregistered dialects are allowed in this context.")
+ .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
+ nb::arg("callback"),
+ "Attaches a diagnostic handler that will receive callbacks.")
+ .def(
+ "enable_multithreading",
+ [](PyMlirContext &self, bool enable) {
+ mlirContextEnableMultithreading(self.get(), enable);
+ },
+ nb::arg("enable"),
+ R"(
+ Enables or disables multi-threading support in the context.
+
+ Args:
+ enable: Whether to enable (True) or disable (False) multi-threading.
+ )")
+ .def(
+ "set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+ // we should disable multi-threading first before setting
+ // new thread pool otherwise the assert in
+ // MLIRContext::setThreadPool will be raised.
+ mlirContextEnableMultithreading(self.get(), false);
+ mlirContextSetThreadPool(self.get(), pool.get());
+ },
+ R"(
+ Sets a custom thread pool for the context to use.
+
+ Args:
+ pool: A ThreadPool object to use for parallel operations.
+
+ Note:
+ Multi-threading is automatically disabled before setting the thread pool.)")
+ .def(
+ "get_num_threads",
+ [](PyMlirContext &self) {
+ return mlirContextGetNumThreads(self.get());
+ },
+ "Gets the number of threads in the context's thread pool.")
+ .def(
+ "_mlir_thread_pool_ptr",
+ [](PyMlirContext &self) {
+ MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
+ std::stringstream ss;
+ ss << pool.ptr;
+ return ss.str();
+ },
+ "Gets the raw pointer to the LLVM thread pool as a string.")
+ .def(
+ "is_registered_operation",
+ [](PyMlirContext &self, std::string &name) {
+ return mlirContextIsRegisteredOperation(
+ self.get(), MlirStringRef{name.data(), name.size()});
+ },
+ nb::arg("operation_name"),
+ R"(
+ Checks whether an operation with the given name is registered.
+
+ Args:
+ operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
+
+ Returns:
+ True if the operation is registered, False otherwise.)")
+ .def(
+ "append_dialect_registry",
+ [](PyMlirContext &self, PyDialectRegistry ®istry) {
+ mlirContextAppendDialectRegistry(self.get(), registry);
+ },
+ nb::arg("registry"),
+ R"(
+ Appends the contents of a dialect registry to the context.
+
+ Args:
+ registry: A DialectRegistry containing dialects to append.)")
+ .def_prop_rw("emit_error_diagnostics",
+ &PyMlirContext::getEmitErrorDiagnostics,
+ &PyMlirContext::setEmitErrorDiagnostics,
+ R"(
+ Controls whether error diagnostics are emitted to diagnostic handlers.
+
+ By default, error diagnostics are captured and reported through MLIRError exceptions.)")
+ .def(
+ "load_all_available_dialects",
+ [](PyMlirContext &self) {
+ mlirContextLoadAllAvailableDialects(self.get());
+ },
+ R"(
+ Loads all dialects available in the registry into the context.
+
+ This eagerly loads all dialects that have been registered, making them
+ immediately available for use.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectDescriptor
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+ .def_prop_ro(
+ "namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ return nb::str(ns.data, ns.length);
+ },
+ "Returns the namespace of the dialect.")
+ .def(
+ "__repr__",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ std::string repr("<DialectDescriptor ");
+ repr.append(ns.data, ns.length);
+ repr.append(">");
+ return repr;
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect descriptor.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialects
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialects>(m, "Dialects")
+ .def(
+ "__getitem__",
+ [=](PyDialects &self, std::string keyName) {
+ MlirDialect dialect =
+ self.getDialectForKey(keyName, /*attrError=*/false);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(keyName, std::move(descriptor));
+ },
+ "Gets a dialect by name using subscript notation.")
+ .def(
+ "__getattr__",
+ [=](PyDialects &self, std::string attrName) {
+ MlirDialect dialect =
+ self.getDialectForKey(attrName, /*attrError=*/true);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(attrName, std::move(descriptor));
+ },
+ "Gets a dialect by name using attribute notation.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialect
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialect>(m, "Dialect")
+ .def(nb::init<nb::object>(), nb::arg("descriptor"),
+ "Creates a Dialect from a DialectDescriptor.")
+ .def_prop_ro(
+ "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
+ "Returns the DialectDescriptor for this dialect.")
+ .def(
+ "__repr__",
+ [](const nb::object &self) {
+ auto clazz = self.attr("__class__");
+ return nb::str("<Dialect ") +
+ self.attr("descriptor").attr("namespace") +
+ nb::str(" (class ") + clazz.attr("__module__") +
+ nb::str(".") + clazz.attr("__name__") + nb::str(")>");
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectRegistry
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectRegistry>(m, "DialectRegistry")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
+ "Gets a capsule wrapping the MlirDialectRegistry.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyDialectRegistry::createFromCapsule,
+ "Creates a DialectRegistry from a capsule wrapping "
+ "`MlirDialectRegistry`.")
+ .def(nb::init<>(), "Creates a new empty dialect registry.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Location
+ //----------------------------------------------------------------------------
+ nb::class_<PyLocation>(m, "Location")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
+ "Gets a capsule wrapping the MlirLocation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
+ "Creates a Location from a capsule wrapping MlirLocation.")
+ .def("__enter__", &PyLocation::contextEnter,
+ "Enters the location as a context manager.")
+ .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the location context manager.")
+ .def(
+ "__eq__",
+ [](PyLocation &self, PyLocation &other) -> bool {
+ return mlirLocationEqual(self, other);
+ },
+ "Compares two locations for equality.")
+ .def(
+ "__eq__", [](PyLocation &self, nb::object other) { return false; },
+ "Compares location with non-location object (always returns False).")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) -> std::optional<PyLocation *> {
+ auto *loc = PyThreadContextEntry::getDefaultLocation();
+ if (!loc)
+ return std::nullopt;
+ return loc;
+ },
+ // clang-format off
+ nb::sig("def current(/) -> Location | None"),
+ // clang-format on
+ "Gets the Location bound to the current thread or raises ValueError.")
+ .def_static(
+ "unknown",
+ [](DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationUnknownGet(context->get()));
+ },
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing an unknown location.")
+ .def_static(
+ "callsite",
+ [](PyLocation callee, const std::vector<PyLocation> &frames,
+ DefaultingPyMlirContext context) {
+ if (frames.empty())
+ throw nb::value_error("No caller frames provided.");
+ MlirLocation caller = frames.back().get();
+ for (const PyLocation &frame :
+ llvm::reverse(llvm::ArrayRef(frames).drop_back()))
+ caller = mlirLocationCallSiteGet(frame.get(), caller);
+ return PyLocation(context->getRef(),
+ mlirLocationCallSiteGet(callee.get(), caller));
+ },
+ nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
+ "Gets a Location representing a caller and callsite.")
+ .def("is_a_callsite", mlirLocationIsACallSite,
+ "Returns True if this location is a CallSiteLoc.")
+ .def_prop_ro(
+ "callee",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCallee(self));
+ },
+ "Gets the callee location from a CallSiteLoc.")
+ .def_prop_ro(
+ "caller",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCaller(self));
+ },
+ "Gets the caller location from a CallSiteLoc.")
+ .def_static(
+ "file",
+ [](std::string filename, int line, int col,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationFileLineColGet(
+ context->get(), toMlirStringRef(filename), line, col));
+ },
+ nb::arg("filename"), nb::arg("line"), nb::arg("col"),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column.")
+ .def_static(
+ "file",
+ [](std::string filename, int startLine, int startCol, int endLine,
+ int endCol, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFileLineColRangeGet(
+ context->get(), toMlirStringRef(filename),
+ startLine, startCol, endLine, endCol));
+ },
+ nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
+ nb::arg("end_line"), nb::arg("end_col"),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column range.")
+ .def("is_a_file", mlirLocationIsAFileLineColRange,
+ "Returns True if this location is a FileLineColLoc.")
+ .def_prop_ro(
+ "filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ },
+ "Gets the filename from a FileLineColLoc.")
+ .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
+ "Gets the start line number from a `FileLineColLoc`.")
+ .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
+ "Gets the start column number from a `FileLineColLoc`.")
+ .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
+ "Gets the end line number from a `FileLineColLoc`.")
+ .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
+ "Gets the end column number from a `FileLineColLoc`.")
+ .def_static(
+ "fused",
+ [](const std::vector<PyLocation> &pyLocations,
+ std::optional<PyAttribute> metadata,
+ DefaultingPyMlirContext context) {
+ llvm::SmallVector<MlirLocation, 4> locations;
+ locations.reserve(pyLocations.size());
+ for (auto &pyLocation : pyLocations)
+ locations.push_back(pyLocation.get());
+ MlirLocation location = mlirLocationFusedGet(
+ context->get(), locations.size(), locations.data(),
+ metadata ? metadata->get() : MlirAttribute{0});
+ return PyLocation(context->getRef(), location);
+ },
+ nb::arg("locations"), nb::arg("metadata") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a fused location with optional "
+ "metadata.")
+ .def("is_a_fused", mlirLocationIsAFused,
+ "Returns True if this location is a `FusedLoc`.")
+ .def_prop_ro(
+ "locations",
+ [](PyLocation &self) {
+ unsigned numLocations = mlirLocationFusedGetNumLocations(self);
+ std::vector<MlirLocation> locations(numLocations);
+ if (numLocations)
+ mlirLocationFusedGetLocations(self, locations.data());
+ std::vector<PyLocation> pyLocations{};
+ pyLocations.reserve(numLocations);
+ for (unsigned i = 0; i < numLocations; ++i)
+ pyLocations.emplace_back(self.getContext(), locations[i]);
+ return pyLocations;
+ },
+ "Gets the list of locations from a `FusedLoc`.")
+ .def_static(
+ "name",
+ [](std::string name, std::optional<PyLocation> childLoc,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationNameGet(
+ context->get(), toMlirStringRef(name),
+ childLoc ? childLoc->get()
+ : mlirLocationUnknownGet(context->get())));
+ },
+ nb::arg("name"), nb::arg("childLoc") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a named location with optional child "
+ "location.")
+ .def("is_a_name", mlirLocationIsAName,
+ "Returns True if this location is a `NameLoc`.")
+ .def_prop_ro(
+ "name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ },
+ "Gets the name string from a `NameLoc`.")
+ .def_prop_ro(
+ "child_loc",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationNameGetChildLoc(self));
+ },
+ "Gets the child location from a `NameLoc`.")
+ .def_static(
+ "from_attr",
+ [](PyAttribute &attribute, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFromAttribute(attribute));
+ },
+ nb::arg("attribute"), nb::arg("context") = nb::none(),
+ "Gets a Location from a `LocationAttr`.")
+ .def_prop_ro(
+ "context",
+ [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Location`.")
+ .def_prop_ro(
+ "attr",
+ [](PyLocation &self) {
+ return PyAttribute(self.getContext(),
+ mlirLocationGetAttribute(self));
+ },
+ "Get the underlying `LocationAttr`.")
+ .def(
+ "emit_error",
+ [](PyLocation &self, std::string message) {
+ mlirEmitError(self, message.c_str());
+ },
+ nb::arg("message"),
+ R"(
+ Emits an error diagnostic at this location.
+
+ Args:
+ message: The error message to emit.)")
+ .def(
+ "__repr__",
+ [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly representation of the location.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Module
+ //----------------------------------------------------------------------------
+ // Exposes PyModule as the weak-referenceable Python `Module` class. The
+ // factories in this chain (`parse`, `parseFile`, `create`) build an
+ // MlirModule through the C API, wrap it with PyModule::forModule, and return
+ // the resulting Python object. Parse failures raise MLIRError carrying the
+ // diagnostics collected by the PyMlirContext::ErrorCapture that is alive
+ // during the C API call.
+ nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
+ "Gets a capsule wrapping the MlirModule.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ R"(
+ Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
+
+ This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
+ prevent double-frees (of the underlying `mlir::Module`).)")
+ .def("_clear_mlir_module", &PyModule::clearMlirModule,
+ R"(
+ Clears the internal MLIR module reference.
+
+ This is used internally to prevent double-free when ownership is transferred
+ via the C API capsule mechanism. Not intended for normal use.)")
+ // `parse` is registered twice: first for `str` input, then for `bytes`
+ // input. Both bodies are identical apart from the argument type. Keep the
+ // registration order, since nanobind considers same-named overloads in the
+ // order they were added.
+ .def_static(
+ "parse",
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "parse",
+ [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ // Same flow as `parse`, but the C API reads the module from `path`.
+ .def_static(
+ "parseFile",
+ [](const std::string &path, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParseFromFile(
+ context->get(), toMlirStringRef(path));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("path"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ // `loc` is resolved by maybeGetTracebackLocation before creating the
+ // empty module at that location.
+ .def_static(
+ "create",
+ [](const std::optional<PyLocation> &loc)
+ -> nb::typed<nb::object, PyModule> {
+ PyLocation pyLoc = maybeGetTracebackLocation(loc);
+ MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("loc") = nb::none(), "Creates an empty module.")
+ .def_prop_ro(
+ "context",
+ [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that created the `Module`.")
+ // Both `operation` and `body` hand the module's own Python object to
+ // PyOperation::forOperation as its third argument.
+ // NOTE(review): presumably this keeps the module alive while the returned
+ // operation/block is referenced; confirm against forOperation's signature.
+ .def_prop_ro(
+ "operation",
+ [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
+ return PyOperation::forOperation(self.getContext(),
+ mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject())
+ .releaseObject();
+ },
+ "Accesses the module as an operation.")
+ .def_prop_ro(
+ "body",
+ [](PyModule &self) {
+ PyOperationRef moduleOp = PyOperation::forOperation(
+ self.getContext(), mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject());
+ PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
+ return returnBlock;
+ },
+ "Return the block for this module.")
+ .def(
+ "dump",
+ [](PyModule &self) {
+ mlirOperationDump(mlirModuleGetOperation(self.get()));
+ },
+ kDumpDocstring)
+ .def(
+ "__str__",
+ [](const nb::object &self) {
+ // Defer to the operation's __str__.
+ return self.attr("operation").attr("__str__")();
+ },
+ nb::sig("def __str__(self) -> str"),
+ R"(
+ Gets the assembly form of the operation with default options.
+
+ If more advanced control over the assembly formatting or I/O options is needed,
+ use the dedicated print or get_asm method, which supports keyword arguments to
+ customize behavior.
+ )")
+ // Equality and hashing are delegated to the C API on the underlying
+ // MlirModule handles.
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a, "Compares two modules for equality.")
+ .def(
+ "__hash__",
+ [](PyModule &self) { return mlirModuleHashValue(self.get()); },
+ "Returns the hash value of the module.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Operation.
+ //----------------------------------------------------------------------------
+ // `_OperationBase` carries the API shared by `Operation` and `OpView`; both
+ // are registered in this file with PyOperationBase as their C++ base class.
+ // Every accessor reaches the concrete PyOperation via self.getOperation().
+ nb::class_<PyOperationBase>(m, "_OperationBase")
+ .def_prop_ro(
+ MLIR_PYTHON_CAPI_PTR_ATTR,
+ [](PyOperationBase &self) {
+ return self.getOperation().getCapsule();
+ },
+ "Gets a capsule wrapping the `MlirOperation`.")
+ // Two `__eq__` overloads: the typed one compares the underlying
+ // MlirOperation handles, and the `nb::object` fallback answers False for
+ // anything else. The typed overload must stay registered first.
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, PyOperationBase &other) {
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
+ },
+ "Compares two operations for equality.")
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, nb::object other) { return false; },
+ "Compares operation with non-operation object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyOperationBase &self) {
+ return mlirOperationHashValue(self.getOperation().get());
+ },
+ "Returns the hash value of the operation.")
+ .def_prop_ro(
+ "attributes",
+ [](PyOperationBase &self) {
+ return PyOpAttributeMap(self.getOperation().getRef());
+ },
+ "Returns a dictionary-like map of operation attributes.")
+ // `context` and `name` call checkValid() before reading the operation.
+ .def_prop_ro(
+ "context",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyOperation &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ return concreteOperation.getContext().getObject();
+ },
+ "Context that owns the operation.")
+ .def_prop_ro(
+ "name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ return mlirIdentifierStr(mlirOperationGetName(operation));
+ },
+ "Returns the fully qualified name of the operation.")
+ // The operand/region/result views below are built from a reference to the
+ // concrete operation (getRef()).
+ .def_prop_ro(
+ "operands",
+ [](PyOperationBase &self) {
+ return PyOpOperandList(self.getOperation().getRef());
+ },
+ "Returns the list of operation operands.")
+ .def_prop_ro(
+ "regions",
+ [](PyOperationBase &self) {
+ return PyRegionList(self.getOperation().getRef());
+ },
+ "Returns the list of operation regions.")
+ .def_prop_ro(
+ "results",
+ [](PyOperationBase &self) {
+ return PyOpResultList(self.getOperation().getRef());
+ },
+ "Returns the list of Operation results.")
+ // getUniqueResult errors unless the op has exactly one result (per the
+ // docstring); the result is then passed through maybeDownCast().
+ .def_prop_ro(
+ "result",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
+ auto &operation = self.getOperation();
+ return PyOpResult(operation.getRef(), getUniqueResult(operation))
+ .maybeDownCast();
+ },
+ "Shortcut to get an op result if it has only one (throws an error "
+ "otherwise).")
+ // Read/write property: the getter wraps mlirOperationGetLocation and the
+ // setter forwards to mlirOperationSetLocation.
+ .def_prop_rw(
+ "location",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ return PyLocation(operation.getContext(),
+ mlirOperationGetLocation(operation.get()));
+ },
+ [](PyOperationBase &self, const PyLocation &location) {
+ PyOperation &operation = self.getOperation();
+ mlirOperationSetLocation(operation.get(), location.get());
+ },
+ nb::for_getter("Returns the source location the operation was "
+ "defined or derived from."),
+ nb::for_setter("Sets the source location the operation was defined "
+ "or derived from."))
+ // An empty optional (Python None) is returned when getParentOperation()
+ // yields nothing.
+ .def_prop_ro(
+ "parent",
+ [](PyOperationBase &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto parent = self.getOperation().getParentOperation();
+ if (parent)
+ return parent->getObject();
+ return {};
+ },
+ "Returns the parent operation, or `None` if at top level.")
+ // Text form via getAsm with every option at its default value (no size
+ // limits, no debug info, custom op form, verification not assumed).
+ .def(
+ "__str__",
+ [](PyOperationBase &self) {
+ return self.getAsm(/*binary=*/false,
+ /*largeElementsLimit=*/std::nullopt,
+ /*largeResourceLimit=*/std::nullopt,
+ /*enableDebugInfo=*/false,
+ /*prettyDebugInfo=*/false,
+ /*printGenericOpForm=*/false,
+ /*useLocalScope=*/false,
+ /*useNameLocAsPrefix=*/false,
+ /*assumeVerified=*/false,
+ /*skipRegions=*/false);
+ },
+ nb::sig("def __str__(self) -> str"),
+ "Returns the assembly form of the operation.")
+ // Two `print` overloads are selected via nb::overload_cast: one driven by
+ // an AsmState, one taking individual formatting flags.
+ .def("print",
+ nb::overload_cast<PyAsmState &, nb::object, bool>(
+ &PyOperationBase::print),
+ nb::arg("state"), nb::arg("file") = nb::none(),
+ nb::arg("binary") = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ state: `AsmState` capturing the operation numbering and flags.
+ file: Optional file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
+ .def("print",
+ nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
+ bool, bool, bool, bool, bool, bool, nb::object,
+ bool, bool>(&PyOperationBase::print),
+ // Careful: Lots of arguments must match up with print method.
+ nb::arg("large_elements_limit") = nb::none(),
+ nb::arg("large_resource_limit") = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
+ nb::arg("binary") = false, nb::arg("skip_regions") = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ large_elements_limit: Whether to elide elements attributes above this
+ number of elements. Defaults to None (no limit).
+ large_resource_limit: Whether to elide resource attributes above this
+ number of characters. Defaults to None (no limit). If large_elements_limit
+ is set and this is None, the behavior will be to use large_elements_limit
+ as large_resource_limit.
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable). Defaults to False.
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
+ prefixes for the SSA identifiers. Defaults to False.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ skip_regions: Whether to skip printing regions. Defaults to False.)")
+ .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
+ nb::arg("desired_version") = nb::none(),
+ R"(
+ Write the bytecode form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to.
+ desired_version: Optional version of bytecode to emit.
+ Returns:
+ The bytecode writer status.)")
+ .def("get_asm", &PyOperationBase::getAsm,
+ // Careful: Lots of arguments must match up with get_asm method.
+ nb::arg("binary") = false,
+ nb::arg("large_elements_limit") = nb::none(),
+ nb::arg("large_resource_limit") = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
+ R"(
+ Gets the assembly form of the operation with all options available.
+
+ Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+ Returns:
+ Either a bytes or str object, depending on the setting of the `binary`
+ argument.)")
+ .def("verify", &PyOperationBase::verify,
+ "Verify the operation. Raises MLIRError if verification fails, and "
+ "returns true otherwise.")
+ .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
+ "Puts self immediately after the other operation in its parent "
+ "block.")
+ .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
+ "Puts self immediately before the other operation in its parent "
+ "block.")
+ .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
+ nb::arg("other"),
+ R"(
+ Checks if this operation is before another in the same block.
+
+ Args:
+ other: Another operation in the same parent block.
+
+ Returns:
+ True if this operation is before `other` in the operation list of the parent block.)")
+ // `ip` is forwarded unchanged to PyOperation::clone.
+ .def(
+ "clone",
+ [](PyOperationBase &self,
+ const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperation().clone(ip);
+ },
+ nb::arg("ip") = nb::none(),
+ R"(
+ Creates a deep copy of the operation.
+
+ Args:
+ ip: Optional insertion point where the cloned operation should be inserted.
+ If None, the current insertion point is used. If False, the operation
+ remains detached.
+
+ Returns:
+ A new Operation that is a clone of this operation.)")
+ // Raises ValueError for an already-detached operation; otherwise detaches
+ // it and returns an OpView of the (now detached) operation.
+ .def(
+ "detach_from_parent",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ if (!operation.isAttached())
+ throw nb::value_error("Detached operation has no parent.");
+
+ operation.detachFromParent();
+ return operation.createOpView();
+ },
+ "Detaches the operation from its parent block.")
+ .def_prop_ro(
+ "attached",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ return operation.isAttached();
+ },
+ "Reports if the operation is attached to its parent block.")
+ .def(
+ "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
+ R"(
+ Erases the operation and frees its memory.
+
+ Note:
+ After erasing, any Python references to the operation become invalid.)")
+ // Default traversal order is MlirWalkPostOrder. The explicit nb::sig
+ // supplies the Python-facing signature with a typed callback.
+ .def("walk", &PyOperationBase::walk, nb::arg("callback"),
+ nb::arg("walk_order") = MlirWalkPostOrder,
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
+ // clang-format on
+ R"(
+ Walks the operation tree with a callback function.
+
+ Args:
+ callback: A callable that takes an Operation and returns a WalkResult.
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+
+ // `Operation` is the concrete, generic operation class; it derives from
+ // `_OperationBase` and adds creation, parsing and capsule interop.
+ nb::class_<PyOperation, PyOperationBase>(m, "Operation")
+ // Validates operands (None entries are rejected), resolves the location
+ // through maybeGetTracebackLocation, then delegates to PyOperation::create.
+ .def_static(
+ "create",
+ [](std::string_view name,
+ std::optional<std::vector<PyType *>> results,
+ std::optional<std::vector<PyValue *>> operands,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors, int regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp,
+ bool inferType) -> nb::typed<nb::object, PyOperation> {
+ // Unpack/validate operands. A None entry arrives as a null PyValue*
+ // and is rejected with ValueError instead of being dereferenced.
+ llvm::SmallVector<MlirValue, 4> mlirOperands;
+ if (operands) {
+ mlirOperands.reserve(operands->size());
+ for (PyValue *operand : *operands) {
+ if (!operand)
+ throw nb::value_error("operand value cannot be None");
+ mlirOperands.push_back(operand->get());
+ }
+ }
+
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ return PyOperation::create(name, results, mlirOperands, attributes,
+ successors, regions, pyLoc, maybeIp,
+ inferType);
+ },
+ nb::arg("name"), nb::arg("results") = nb::none(),
+ nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
+ nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
+ nb::arg("infer_type") = false,
+ R"(
+ Creates a new operation.
+
+ Args:
+ name: Operation name (e.g. `dialect.operation`).
+ results: Optional sequence of Type representing op result types.
+ operands: Optional operands of the operation.
+ attributes: Optional Dict of {str: Attribute}.
+ successors: Optional List of Block for the operation's successors.
+ regions: Number of regions to create (default = 0).
+ loc: Optional Location object (defaults to resolve from context manager).
+ ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
+ infer_type: Whether to infer result types (default = False).
+ Returns:
+ A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
+ // Parses via PyOperation::parse and returns the result as an OpView.
+ .def_static(
+ "parse",
+ [](const std::string &sourceStr, const std::string &sourceName,
+ DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyOpView> {
+ return PyOperation::parse(context->getRef(), sourceStr, sourceName)
+ ->createOpView();
+ },
+ nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
+ nb::arg("context") = nb::none(),
+ "Parses an operation. Supports both text assembly format and binary "
+ "bytecode format.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
+ "Gets a capsule wrapping the MlirOperation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyOperation::createFromCapsule,
+ "Creates an Operation from a capsule wrapping MlirOperation.")
+ // On a plain Operation, `operation` is the identity.
+ .def_prop_ro(
+ "operation",
+ [](nb::object self) -> nb::typed<nb::object, PyOperation> {
+ return self;
+ },
+ "Returns self (the operation).")
+ .def_prop_ro(
+ "opview",
+ [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
+ return self.createOpView();
+ },
+ R"(
+ Returns an OpView of this operation.
+
+ Note:
+ If the operation has a registered and loaded dialect then this OpView will
+ be concrete wrapper class.)")
+ .def_prop_ro("block", &PyOperation::getBlock,
+ "Returns the block containing this operation.")
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &self) {
+ return PyOpSuccessors(self.getOperation().getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def("_set_invalid", &PyOperation::setInvalid,
+ "Invalidate the operation.");
+
+ // `OpView` is the base for operation wrapper classes. It can wrap an existing
+ // Operation (first `__init__`) or build a new one through
+ // PyOpView::buildGeneric (second `__init__`, which placement-news the result
+ // into `self`).
+ auto opViewClass =
+ nb::class_<PyOpView, PyOperationBase>(m, "OpView")
+ .def(nb::init<nb::typed<nb::object, PyOperation>>(),
+ nb::arg("operation"))
+ .def(
+ "__init__",
+ [](PyOpView *self, std::string_view name,
+ std::tuple<int, bool> opRegionSpec,
+ nb::object operandSegmentSpecObj,
+ nb::object resultSegmentSpecObj,
+ std::optional<nb::list> resultTypeList, nb::list operandList,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp) {
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ new (self) PyOpView(PyOpView::buildGeneric(
+ name, opRegionSpec, operandSegmentSpecObj,
+ resultSegmentSpecObj, resultTypeList, operandList,
+ attributes, successors, regions, pyLoc, maybeIp));
+ },
+ nb::arg("name"), nb::arg("opRegionSpec"),
+ nb::arg("operandSegmentSpecObj") = nb::none(),
+ nb::arg("resultSegmentSpecObj") = nb::none(),
+ nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
+ nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(),
+ nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
+ nb::arg("ip") = nb::none())
+ // `operation` returns the wrapped Operation object; `opview` is the
+ // identity on an OpView.
+ .def_prop_ro(
+ "operation",
+ [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperationObject();
+ })
+ .def_prop_ro("opview",
+ [](nb::object self) -> nb::typed<nb::object, PyOpView> {
+ return self;
+ })
+ .def(
+ "__str__",
+ [](PyOpView &self) { return nb::str(self.getOperationObject()); })
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &self) {
+ return PyOpSuccessors(self.getOperation().getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def(
+ "_set_invalid",
+ [](PyOpView &self) { self.getOperation().setInvalid(); },
+ "Invalidate the operation.");
+ // Class-level ODS defaults on the base OpView: region spec (0, True) and no
+ // operand/result segment specs. build_generic reads these from `cls`.
+ // NOTE(review): generated subclasses presumably override them; confirm.
+ opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
+ opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
+ opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
+ // It is faster to pass the operation_name, ods_regions, and
+ // ods_operand_segments/ods_result_segments as arguments to the constructor,
+ // rather than to access them as attributes.
+ opViewClass.attr("build_generic") = classmethod(
+ [](nb::handle cls, std::optional<nb::list> resultTypeList,
+ nb::list operandList, std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions, std::optional<PyLocation> location,
+ const nb::object &maybeIp) {
+ std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ std::tuple<int, bool> opRegionSpec =
+ nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+ nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
+ nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
+ resultSegmentSpec, resultTypeList,
+ operandList, attributes, successors,
+ regions, pyLoc, maybeIp);
+ },
+ nb::arg("cls"), nb::arg("results") = nb::none(),
+ nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
+ nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
+ "Builds a specific, generated OpView based on class level attributes.");
+ // Parses an operation and raises MLIRError unless its name matches
+ // cls.OPERATION_NAME; on success it is wrapped in `cls`.
+ opViewClass.attr("parse") = classmethod(
+ [](const nb::object &cls, const std::string &sourceStr,
+ const std::string &sourceName,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
+ PyOperationRef parsed =
+ PyOperation::parse(context->getRef(), sourceStr, sourceName);
+
+ // Check if the expected operation was parsed, and cast to the
+ // appropriate `OpView` subclass if successful.
+ // NOTE: This accesses attributes that have been automatically added to
+ // `OpView` subclasses, and is not intended to be used on `OpView`
+ // directly.
+ std::string clsOpName =
+ nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ MlirStringRef identifier =
+ mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
+ std::string_view parsedOpName(identifier.data, identifier.length);
+ if (clsOpName != parsedOpName)
+ throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
+ parsedOpName + "'");
+ return PyOpView::constructDerived(cls, parsed.getObject());
+ },
+ nb::arg("cls"), nb::arg("source"), nb::kw_only(),
+ nb::arg("source_name") = "", nb::arg("context") = nb::none(),
+ "Parses a specific, generated OpView based on class level attributes.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyRegion.
+ //----------------------------------------------------------------------------
+ // Exposes PyRegion as the Python `Region` class. Block views and iterators
+ // are built from the region's parent operation plus the raw MlirRegion.
+ nb::class_<PyRegion>(m, "Region")
+ .def_prop_ro(
+ "blocks",
+ [](PyRegion &self) {
+ return PyBlockList(self.getParentOperation(), self.get());
+ },
+ "Returns a forward-optimized sequence of blocks.")
+ .def_prop_ro(
+ "owner",
+ [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
+ return self.getParentOperation()->createOpView();
+ },
+ "Returns the operation owning this region.")
+ // Starts iteration at mlirRegionGetFirstBlock after a validity check.
+ .def(
+ "__iter__",
+ [](PyRegion &self) {
+ self.checkValid();
+ MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
+ return PyBlockIterator(self.getParentOperation(), firstBlock);
+ },
+ "Iterates over blocks in the region.")
+ // Two `__eq__` overloads: Region-vs-Region compares raw pointers; the
+ // `nb::object` fallback returns False. Keep the typed overload first.
+ .def(
+ "__eq__",
+ [](PyRegion &self, PyRegion &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two regions for pointer equality.")
+ .def(
+ "__eq__", [](PyRegion &self, nb::object &other) { return false; },
+ "Compares region with non-region object (always returns False).");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyBlock.
+ //----------------------------------------------------------------------------
+ // Exposes PyBlock as the Python `Block` class. Every derived view (region,
+ // arguments, operations, new blocks) is built from the block's parent
+ // operation together with the raw MlirBlock/MlirRegion handle.
+ nb::class_<PyBlock>(m, "Block")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
+ "Gets a capsule wrapping the MlirBlock.")
+ .def_prop_ro(
+ "owner",
+ [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
+ return self.getParentOperation()->createOpView();
+ },
+ "Returns the owning operation of this block.")
+ .def_prop_ro(
+ "region",
+ [](PyBlock &self) {
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ return PyRegion(self.getParentOperation(), region);
+ },
+ "Returns the owning region of this block.")
+ .def_prop_ro(
+ "arguments",
+ [](PyBlock &self) {
+ return PyBlockArgumentList(self.getParentOperation(), self.get());
+ },
+ "Returns a list of block arguments.")
+ .def(
+ "add_argument",
+ [](PyBlock &self, const PyType &type, const PyLocation &loc) {
+ return PyBlockArgument(self.getParentOperation(),
+ mlirBlockAddArgument(self.get(), type, loc));
+ },
+ "type"_a, "loc"_a,
+ R"(
+ Appends an argument of the specified type to the block.
+
+ Args:
+ type: The type of the argument to add.
+ loc: The source location for the argument.
+
+ Returns:
+ The newly added block argument.)")
+ .def(
+ "erase_argument",
+ [](PyBlock &self, unsigned index) {
+ return mlirBlockEraseArgument(self.get(), index);
+ },
+ nb::arg("index"),
+ R"(
+ Erases the argument at the specified index.
+
+ Args:
+ index: The index of the argument to erase.)")
+ .def_prop_ro(
+ "operations",
+ [](PyBlock &self) {
+ return PyOperationList(self.getParentOperation(), self.get());
+ },
+ "Returns a forward-optimized sequence of operations.")
+ // Builds the block with createBlock and inserts it at index 0 of `parent`.
+ .def_static(
+ "create_at_start",
+ [](PyRegion &parent, const nb::sequence &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ parent.checkValid();
+ MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
+ mlirRegionInsertOwnedBlock(parent, 0, block);
+ return PyBlock(parent.getParentOperation(), block);
+ },
+ nb::arg("parent"), nb::arg("arg_types") = nb::list(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block at the beginning of the given "
+ "region (with given argument types and locations).")
+ // If the block already belongs to a region, it is detached first so the
+ // append into `region` transfers ownership.
+ .def(
+ "append_to",
+ [](PyBlock &self, PyRegion &region) {
+ MlirBlock b = self.get();
+ if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
+ mlirBlockDetach(b);
+ mlirRegionAppendOwnedBlock(region.get(), b);
+ },
+ nb::arg("region"),
+ R"(
+ Appends this block to a region.
+
+ Transfers ownership if the block is currently owned by another region.
+
+ Args:
+ region: The region to append the block to.)")
+ // create_before/create_after take the argument types as positional
+ // *args and `arg_locs` as keyword-only, then insert the new block next
+ // to `self` inside self's parent region.
+ // NOTE(review): no check is made here that `self` has a parent region.
+ .def(
+ "create_before",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block before this block "
+ "(with given argument types and locations).")
+ .def(
+ "create_after",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block after this block "
+ "(with given argument types and locations).")
+ .def(
+ "__iter__",
+ [](PyBlock &self) {
+ self.checkValid();
+ MlirOperation firstOperation =
+ mlirBlockGetFirstOperation(self.get());
+ return PyOperationIterator(self.getParentOperation(),
+ firstOperation);
+ },
+ "Iterates over operations in the block.")
+ // Two `__eq__` overloads: Block-vs-Block compares raw pointers; the
+ // `nb::object` fallback returns False. Keep the typed overload first.
+ // `__hash__` hashes the same pointer, so it agrees with `__eq__`.
+ .def(
+ "__eq__",
+ [](PyBlock &self, PyBlock &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two blocks for pointer equality.")
+ .def(
+ "__eq__", [](PyBlock &self, nb::object &other) { return false; },
+ "Compares block with non-block object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyBlock &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the block.")
+ .def(
+ "__str__",
+ [](PyBlock &self) {
+ self.checkValid();
+ PyPrintAccumulator printAccum;
+ mlirBlockPrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the block.")
+ // Detaches an attached operation first, appends it to this block, then
+ // records the block's parent operation as the op's new parent object.
+ .def(
+ "append",
+ [](PyBlock &self, PyOperationBase &operation) {
+ if (operation.getOperation().isAttached())
+ operation.getOperation().detachFromParent();
+
+ MlirOperation mlirOperation = operation.getOperation().get();
+ mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
+ operation.getOperation().setAttached(
+ self.getParentOperation().getObject());
+ },
+ nb::arg("operation"),
+ R"(
+ Appends an operation to this block.
+
+ If the operation is currently in another block, it will be moved.
+
+ Args:
+ operation: The operation to append to the block.)")
+ .def_prop_ro(
+ "successors",
+ [](PyBlock &self) {
+ return PyBlockSuccessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block successors.")
+ .def_prop_ro(
+ "predecessors",
+ [](PyBlock &self) {
+ return PyBlockPredecessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block predecessors.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyInsertionPoint.
+ //----------------------------------------------------------------------------
+
+ nb::class_<PyInsertionPoint>(m, "InsertionPoint")
+ .def(nb::init<PyBlock &>(), nb::arg("block"),
+ "Inserts after the last operation but still inside the block.")
+ .def("__enter__", &PyInsertionPoint::contextEnter,
+ "Enters the insertion point as a context manager.")
+ .def("__exit__", &PyInsertionPoint::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none(),
+ "Exits the insertion point context manager.")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) {
+ auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
+ if (!ip)
+ throw nb::value_error("No current InsertionPoint");
+ return ip;
+ },
+ nb::sig("def current(/) -> InsertionPoint"),
+ "Gets the InsertionPoint bound to the current thread or raises "
+ "ValueError if none has been set.")
+ .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
+ "Inserts before a referenced operation.")
+ .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
+ nb::arg("block"),
+ R"(
+ Creates an insertion point at the beginning of a block.
+
+ Args:
+ block: The block at whose beginning operations should be inserted.
+
+ Returns:
+ An InsertionPoint at the block's beginning.)")
+ .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
+ nb::arg("block"),
+ R"(
+ Creates an insertion point before a block's terminator.
+
+ Args:
+ block: The block whose terminator to insert before.
+
+ Returns:
+ An InsertionPoint before the terminator.
+
+ Raises:
+ ValueError: If the block has no terminator.)")
+ .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+ R"(
+ Creates an insertion point immediately after an operation.
+
+ Args:
+ operation: The operation after which to insert.
+
+ Returns:
+ An InsertionPoint after the operation.)")
+ .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
+ R"(
+ Inserts an operation at this insertion point.
+
+ Args:
+ operation: The operation to insert.)")
+ .def_prop_ro(
+ "block", [](PyInsertionPoint &self) { return self.getBlock(); },
+ "Returns the block that this `InsertionPoint` points to.")
+ .def_prop_ro(
+ "ref_operation",
+ [](PyInsertionPoint &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto refOperation = self.getRefOperation();
+ if (refOperation)
+ return refOperation->getObject();
+ return {};
+ },
+ "The reference operation before which new operations are "
+ "inserted, or None if the insertion point is at the end of "
+ "the block.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyAttribute.
+ //----------------------------------------------------------------------------
+ nb::class_<PyAttribute>(m, "Attribute")
+ // Delegate to the PyAttribute copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirAttribute.
+ .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
+ "Casts the passed attribute to the generic `Attribute`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
+ "Gets a capsule wrapping the MlirAttribute.")
+ .def_static(
+ MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
+ "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
+ .def_static(
+ "parse",
+ [](const std::string &attrSpec, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyAttribute> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr = mlirAttributeParseGet(
+ context->get(), toMlirStringRef(attrSpec));
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Unable to parse attribute", errors.take());
+ return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ "Parses an attribute from an assembly form. Raises an `MLIRError` on "
+ "failure.")
+ .def_prop_ro(
+ "context",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Attribute`.")
+ .def_prop_ro(
+ "type",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirAttributeGetType(self))
+ .maybeDownCast();
+ },
+ "Returns the type of the `Attribute`.")
+ .def(
+ "get_named",
+ [](PyAttribute &self, std::string name) {
+ return PyNamedAttribute(self, std::move(name));
+ },
+ nb::keep_alive<0, 1>(),
+ R"(
+ Binds a name to the attribute, creating a `NamedAttribute`.
+
+ Args:
+ name: The name to bind to the `Attribute`.
+
+ Returns:
+ A `NamedAttribute` with the given name and this attribute.)")
+ .def(
+ "__eq__",
+ [](PyAttribute &self, PyAttribute &other) { return self == other; },
+ "Compares two attributes for equality.")
+ .def(
+ "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
+ "Compares attribute with non-attribute object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyAttribute &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the attribute.")
+ .def(
+ "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
+ kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyAttribute &self) {
+ PyPrintAccumulator printAccum;
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the Attribute.")
+ .def(
+ "__repr__",
+ [](PyAttribute &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, attribute values are generally considered useful and
+ // are printed. This may need to be re-evaluated if debug dumps end
+ // up being excessive.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Attribute(");
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the attribute.")
+ .def_prop_ro(
+ "typeid",
+ [](PyAttribute &self) {
+ MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ return PyTypeID(mlirTypeID);
+ },
+ "Returns the `TypeID` of the attribute.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the attribute to a more specific attribute if possible.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyNamedAttribute
+ //----------------------------------------------------------------------------
+ nb::class_<PyNamedAttribute>(m, "NamedAttribute")
+ .def(
+ "__repr__",
+ [](PyNamedAttribute &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("NamedAttribute(");
+ printAccum.parts.append(
+ nb::str(mlirIdentifierStr(self.namedAttr.name).data,
+ mlirIdentifierStr(self.namedAttr.name).length));
+ printAccum.parts.append("=");
+ mlirAttributePrint(self.namedAttr.attribute,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the named attribute.")
+ .def_prop_ro(
+ "name",
+ [](PyNamedAttribute &self) {
+ return mlirIdentifierStr(self.namedAttr.name);
+ },
+ "The name of the `NamedAttribute` binding.")
+ .def_prop_ro(
+ "attr",
+ [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
+ nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
+ "The underlying generic attribute of the `NamedAttribute` binding.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyType.
+ //----------------------------------------------------------------------------
+ nb::class_<PyType>(m, "Type")
+ // Delegate to the PyType copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirType.
+ .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
+ "Casts the passed type to the generic `Type`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
+ "Gets a capsule wrapping the `MlirType`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
+ "Creates a Type from a capsule wrapping `MlirType`.")
+ .def_static(
+ "parse",
+ [](std::string typeSpec,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type =
+ mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Unable to parse type", errors.take());
+ return PyType(context.get()->getRef(), type).maybeDownCast();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ R"(
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
+
+ See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
+ .def_prop_ro(
+ "context",
+ [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Type`.")
+ .def(
+ "__eq__", [](PyType &self, PyType &other) { return self == other; },
+ "Compares two types for equality.")
+ .def(
+ "__eq__", [](PyType &self, nb::object &other) { return false; },
+ nb::arg("other").none(),
+ "Compares type with non-type object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyType &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the `Type`.")
+ .def(
+ "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyType &self) {
+ PyPrintAccumulator printAccum;
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the `Type`.")
+ .def(
+ "__repr__",
+ [](PyType &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, types are an exception as they typically have compact
+ // assembly forms and printing them is useful.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Type(");
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the `Type`.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) -> nb::typed<nb::object, PyType> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the Type to a more specific `Type` if possible.")
+ .def_prop_ro(
+ "typeid",
+ [](PyType &self) {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ if (!mlirTypeIDIsNull(mlirTypeID))
+ return PyTypeID(mlirTypeID);
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
+ throw nb::value_error(
+ (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
+ },
+ "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
+ "`Type` has no "
+ "`TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyTypeID.
+ //----------------------------------------------------------------------------
+ nb::class_<PyTypeID>(m, "TypeID")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
+ "Gets a capsule wrapping the `MlirTypeID`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
+ "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
+ // Note, this tests whether the underlying TypeIDs are the same,
+ // not whether the wrapper MlirTypeIDs are the same, nor whether
+ // the Python objects are the same (i.e., PyTypeID is a value type).
+ .def(
+ "__eq__",
+ [](PyTypeID &self, PyTypeID &other) { return self == other; },
+ "Compares two `TypeID`s for equality.")
+ .def(
+ "__eq__",
+ [](PyTypeID &self, const nb::object &other) { return false; },
+ "Compares TypeID with non-TypeID object (always returns False).")
+ // Note, this gives the hash value of the underlying TypeID, not the
+ // hash value of the Python object, nor the hash value of the
+ // MlirTypeID wrapper.
+ .def(
+ "__hash__",
+ [](PyTypeID &self) {
+ return static_cast<size_t>(mlirTypeIDHashValue(self));
+ },
+ "Returns the hash value of the `TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Value.
+ //----------------------------------------------------------------------------
+ m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
+
+ nb::class_<PyValue>(m, "Value", nb::is_generic(),
+ nb::sig("class Value(Generic[_T])"))
+ .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
+ "Creates a Value reference from another `Value`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
+ "Gets a capsule wrapping the `MlirValue`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
+ "Creates a `Value` from a capsule wrapping `MlirValue`.")
+ .def_prop_ro(
+ "context",
+ [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getParentOperation()->getContext().getObject();
+ },
+ "Context in which the value lives.")
+ .def(
+ "dump", [](PyValue &self) { mlirValueDump(self.get()); },
+ kDumpDocstring)
+ .def_prop_ro(
+ "owner",
+ [](PyValue &self) -> nb::object {
+ MlirValue v = self.get();
+ if (mlirValueIsAOpResult(v)) {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match "
+ "that in "
+ "the IR");
+ return self.getParentOperation().getObject();
+ }
+
+ if (mlirValueIsABlockArgument(v)) {
+ MlirBlock block = mlirBlockArgumentGetOwner(self.get());
+ return nb::cast(PyBlock(self.getParentOperation(), block));
+ }
+
+ assert(false && "Value must be a block argument or an op result");
+ return nb::none();
+ },
+ "Returns the owner of the value (`Operation` for results, `Block` "
+ "for "
+ "arguments).")
+ .def_prop_ro(
+ "uses",
+ [](PyValue &self) {
+ return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
+ },
+ "Returns an iterator over uses of this value.")
+ .def(
+ "__eq__",
+ [](PyValue &self, PyValue &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two values for pointer equality.")
+ .def(
+ "__eq__", [](PyValue &self, nb::object other) { return false; },
+ "Compares value with non-value object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyValue &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the value.")
+ .def(
+ "__str__",
+ [](PyValue &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Value(");
+ mlirValuePrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ R"(
+ Returns the string form of the value.
+
+ If the value is a block argument, this is the assembly form of its type and the
+ position in the argument list. If the value is an operation result, this is
+ equivalent to printing the operation that produced it.
+ )")
+ .def(
+ "get_name",
+ [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
+ PyPrintAccumulator printAccum;
+ MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+ if (useLocalScope)
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ if (useNameLocAsPrefix)
+ mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
+ MlirAsmState valueState =
+ mlirAsmStateCreateForValue(self.get(), flags);
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ mlirOpPrintingFlagsDestroy(flags);
+ mlirAsmStateDestroy(valueState);
+ return printAccum.join();
+ },
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ R"(
+ Returns the string form of value as an operand.
+
+ Args:
+ use_local_scope: Whether to use local scope for naming.
+ use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
+
+ Returns:
+ The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
+ .def(
+ "get_name",
+ [](PyValue &self, PyAsmState &state) {
+ PyPrintAccumulator printAccum;
+ MlirAsmState valueState = state.get();
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ nb::arg("state"),
+ "Returns the string form of value as an operand (i.e., the ValueID).")
+ .def_prop_ro(
+ "type",
+ [](PyValue &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()))
+ .maybeDownCast();
+ },
+ "Returns the type of the value.")
+ .def(
+ "set_type",
+ [](PyValue &self, const PyType &type) {
+ mlirValueSetType(self.get(), type);
+ },
+ nb::arg("type"), "Sets the type of the value.",
+ nb::sig("def set_type(self, type: _T)"))
+ .def(
+ "replace_all_uses_with",
+ [](PyValue &self, PyValue &with) {
+ mlirValueReplaceAllUsesOfWith(self.get(), with.get());
+ },
+ "Replace all uses of value with the new value, updating anything in "
+ "the IR that uses `self` to use the other value instead.")
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, const nb::list &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (nb::handle exception : exceptions) {
+ exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
+ }
+
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with,
+ std::vector<PyOperation> &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (PyOperation &exception : exceptions)
+ exceptionOps.push_back(exception);
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the `Value` to a more specific kind if possible.")
+ .def_prop_ro(
+ "location",
+ [](MlirValue self) {
+ return PyLocation(
+ PyMlirContext::forContext(mlirValueGetContext(self)),
+ mlirValueGetLocation(self));
+ },
+ "Returns the source location of the value.");
+
+ PyBlockArgument::bind(m);
+ PyOpResult::bind(m);
+ PyOpOperand::bind(m);
+
+ nb::class_<PyAsmState>(m, "AsmState")
+ .def(nb::init<PyValue &, bool>(), nb::arg("value"),
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an `AsmState` for consistent SSA value naming.
+
+ Args:
+ value: The value to create state for.
+ use_local_scope: Whether to use local scope for naming.)")
+ .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an AsmState for consistent SSA value naming.
+
+ Args:
+ op: The operation to create state for.
+ use_local_scope: Whether to use local scope for naming.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of SymbolTable.
+ //----------------------------------------------------------------------------
+ nb::class_<PySymbolTable>(m, "SymbolTable")
+ .def(nb::init<PyOperationBase &>(),
+ R"(
+ Creates a symbol table for an operation.
+
+ Args:
+ operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
+
+ Raises:
+ TypeError: If the operation is not a symbol table.)")
+ .def(
+ "__getitem__",
+ [](PySymbolTable &self,
+ const std::string &name) -> nb::typed<nb::object, PyOpView> {
+ return self.dunderGetItem(name);
+ },
+ R"(
+ Looks up a symbol by name in the symbol table.
+
+ Args:
+ name: The name of the symbol to look up.
+
+ Returns:
+ The operation defining the symbol.
+
+ Raises:
+ KeyError: If the symbol is not found.)")
+ .def("insert", &PySymbolTable::insert, nb::arg("operation"),
+ R"(
+ Inserts a symbol operation into the symbol table.
+
+ Args:
+ operation: An operation with a symbol name to insert.
+
+ Returns:
+ The symbol name attribute of the inserted operation.
+
+ Raises:
+ ValueError: If the operation does not have a symbol name.)")
+ .def("erase", &PySymbolTable::erase, nb::arg("operation"),
+ R"(
+ Erases a symbol operation from the symbol table.
+
+ Args:
+ operation: The symbol operation to erase.
+
+ Note:
+ The operation is also erased from the IR and invalidated.)")
+ .def("__delitem__", &PySymbolTable::dunderDel,
+ "Deletes a symbol by name from the symbol table.")
+ .def(
+ "__contains__",
+ [](PySymbolTable &table, const std::string &name) {
+ return !mlirOperationIsNull(mlirSymbolTableLookup(
+ table, mlirStringRefCreate(name.data(), name.length())));
+ },
+ "Checks if a symbol with the given name exists in the table.")
+ // Static helpers.
+ .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
+ nb::arg("symbol"), nb::arg("name"),
+ "Sets the symbol name for a symbol operation.")
+ .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
+ nb::arg("symbol"),
+ "Gets the symbol name from a symbol operation.")
+ .def_static("get_visibility", &PySymbolTable::getVisibility,
+ nb::arg("symbol"),
+ "Gets the visibility attribute of a symbol operation.")
+ .def_static("set_visibility", &PySymbolTable::setVisibility,
+ nb::arg("symbol"), nb::arg("visibility"),
+ "Sets the visibility attribute of a symbol operation.")
+ .def_static("replace_all_symbol_uses",
+ &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
+ nb::arg("new_symbol"), nb::arg("from_op"),
+ "Replaces all uses of a symbol with a new symbol name within "
+ "the given operation.")
+ .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
+ nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
+ nb::arg("callback"),
+ "Walks symbol tables starting from an operation with a "
+ "callback function.");
+
+ // Container bindings.
+ PyBlockArgumentList::bind(m);
+ PyBlockIterator::bind(m);
+ PyBlockList::bind(m);
+ PyBlockSuccessors::bind(m);
+ PyBlockPredecessors::bind(m);
+ PyOperationIterator::bind(m);
+ PyOperationList::bind(m);
+ PyOpAttributeMap::bind(m);
+ PyOpOperandIterator::bind(m);
+ PyOpOperandList::bind(m);
+ PyOpResultList::bind(m);
+ PyOpSuccessors::bind(m);
+ PyRegionIterator::bind(m);
+ PyRegionList::bind(m);
+
+ // Debug bindings.
+ PyGlobalDebugFlag::bind(m);
+
+ // Attribute builder getter.
+ PyAttrBuilderMap::bind(m);
+
+ // nb::register_exception_translator([](const std::exception_ptr &p,
+ // void *payload) {
+ // // We can't define exceptions with custom fields through pybind, so
+ // instead
+ // // the exception class is defined in python and imported here.
+ // try {
+ // if (p)
+ // std::rethrow_exception(p);
+ // } catch (const MLIRError &e) {
+ // nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ // .attr("MLIRError")(e.message, e.errorDiagnostics);
+ // PyErr_SetObject(PyExc_Exception, obj.ptr());
+ // }
+ // });
+}
+
+namespace mlir::python { // Entry points defined in other IR*.cpp translation units.
+void populateIRAffine(nb::module_ &m);
+void populateIRAttributes(nb::module_ &m); // Defined in IRAttributes.cpp.
+void populateIRInterfaces(nb::module_ &m);
+void populateIRTypes(nb::module_ &m);
+void registerMLIRErrorInIRCore(); // IRCore.cpp; NOTE(review): NB_MODULE below also registers an identical MLIRError translator inline right after calling this - one is redundant.
+} // namespace mlir::python
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
@@ -158,4 +2415,18 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
+ registerMLIRErrorInIRCore();
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 572afa902746d..b1fd48cf410af 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,8 +8,8 @@
#include "Pass.h"
-#include "Globals.h"
-#include "IRModule.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/Pass.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..0221bd10e723e 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
#define MLIR_BINDINGS_PYTHON_PASS_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0df9d0cbc7ffc..84f97abbae569 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,7 +8,7 @@
#include "Rewrite.h"
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index ae89e2b9589f1..f8ffdc7bdc458 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index e9b1aff0455e6..3a4af1f066298 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -532,6 +532,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
MainModule.cpp
+ IRAffine.cpp
+ IRAttributes.cpp
+ IRInterfaces.cpp
+ IRTypes.cpp
Pass.cpp
Rewrite.cpp
@@ -991,12 +995,8 @@ get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso
list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
add_mlir_library(MLIRPythonSupport
- ${PYTHON_SOURCE_DIR}/Globals.cpp
- ${PYTHON_SOURCE_DIR}/IRAffine.cpp
- ${PYTHON_SOURCE_DIR}/IRAttributes.cpp
${PYTHON_SOURCE_DIR}/IRCore.cpp
- ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
- ${PYTHON_SOURCE_DIR}/IRTypes.cpp
+ ${PYTHON_SOURCE_DIR}/Globals.cpp
EXCLUDE_FROM_LIBMLIR
SHARED
LINK_COMPONENTS
@@ -1014,6 +1014,13 @@ set_target_properties(MLIRPythonSupport PROPERTIES
RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
)
+set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES # Emit the library target read from _mlir.dso into _mlir_libs, alongside MLIRPythonSupport.
+ LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+ BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs" # NOTE(review): not a built-in CMake target property; likely a no-op.
+ # Needed for Windows, where DLLs are RUNTIME outputs and import libs are ARCHIVE outputs (harmless elsewhere).
+ RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+ ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+)
set(eh_rtti_enable)
if(MSVC)
set(eh_rtti_enable /EHsc /GR)
@@ -1035,4 +1042,4 @@ endif()
target_link_libraries(
MLIRPythonModules.extension._mlir.dso
PUBLIC MLIRPythonSupport)
-
+target_compile_definitions(MLIRPythonSupport PRIVATE "$<$<BOOL:${MLIR_BINDINGS_PYTHON_NB_DOMAIN}>:NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}>") # Only define NB_DOMAIN when a domain is set: an empty NB_DOMAIN is not the same as an undefined one.
>From a11de9246f9988d4c943d61da9b139b303f80fb4 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 13:46:11 -0800
Subject: [PATCH 03/27] works
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 -
mlir/include/mlir/Bindings/Python/Globals.h | 1 -
mlir/include/mlir/Bindings/Python/IRCore.h | 19 +++++++++++
mlir/lib/Bindings/Python/IRAttributes.cpp | 14 +-------
mlir/lib/Bindings/Python/IRCore.cpp | 5 +--
mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 18 ++---------
mlir/lib/Bindings/Python/MainModule.cpp | 32 ++-----------------
mlir/lib/Bindings/Python/Pass.cpp | 3 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
mlir/python/CMakeLists.txt | 12 +++++--
.../python/lib/PythonTestModuleNanobind.cpp | 31 +++++++++++-------
12 files changed, 60 insertions(+), 80 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 882781736b493..47b7a58fc821d 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -601,7 +601,6 @@ function(add_mlir_python_common_capi_library name)
# Generate the aggregate .so that everything depends on.
add_mlir_aggregate(${name}
SHARED
- DISABLE_INSTALL
EMBED_LIBS ${_embed_libs}
)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index fea7a201453ce..19ffe8164d727 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -245,7 +245,6 @@ struct PyGlobalDebugFlag {
static nanobind::ft_mutex mutex;
};
-
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 488196ea42e44..66a6272eaaf68 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1325,6 +1325,25 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
+inline void registerMLIRError() { // Register a translator turning a thrown C++ MLIRError into a Python ir.MLIRError.
+ nanobind::register_exception_translator(
+ [](const std::exception_ptr &p, void *payload) {
+ // We can't define exceptions with custom fields through nanobind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nanobind::object obj =
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) // NOTE(review): public header now depends on MLIR_PYTHON_PACKAGE_PREFIX being defined by every includer - confirm for out-of-tree users.
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
+}
+
+void registerMLIRErrorInCore(); // Defined in IRCore.cpp; appears to register the same translator - consider keeping only one.
+
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 36367e658697c..4323374a5d5b7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1852,18 +1852,6 @@ void populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
+ registerMLIRError();
}
} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 88cffb64906d7..ea1e62b8165ad 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -920,9 +920,6 @@ nb::object PyOperation::create(std::string_view name,
PyMlirContext::ErrorCapture errors(location.getContext());
MlirOperation operation = mlirOperationCreate(&state);
if (!operation.ptr) {
- for (auto take : errors.take()) {
- std::cout << take.message << "\n";
- }
throw MLIRError("Operation creation failed", errors.take());
}
PyOperationRef created =
@@ -1672,7 +1669,7 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-void registerMLIRErrorInIRCore() {
+void registerMLIRErrorInCore() {
nb::register_exception_translator([](const std::exception_ptr &p,
void *payload) {
// We can't define exceptions with custom fields through pybind, so
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index f1e494c375523..78d1f977b2ebc 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,11 +12,11 @@
#include <utility>
#include <vector>
-#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 294ab91a059e2..7d9a0f16c913a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -13,10 +13,10 @@
#include <optional>
-#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
@@ -1175,18 +1175,6 @@ void populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
-});
-}
+ registerMLIRError();
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 686c55ee1e6a8..643851fcaf046 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -2250,21 +2250,6 @@ static void populateIRCore(nb::module_ &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
-
- // nb::register_exception_translator([](const std::exception_ptr &p,
- // void *payload) {
- // // We can't define exceptions with custom fields through pybind, so
- // instead
- // // the exception class is defined in python and imported here.
- // try {
- // if (p)
- // std::rethrow_exception(p);
- // } catch (const MLIRError &e) {
- // nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- // .attr("MLIRError")(e.message, e.errorDiagnostics);
- // PyErr_SetObject(PyExc_Exception, obj.ptr());
- // }
- // });
}
namespace mlir::python {
@@ -2272,7 +2257,6 @@ void populateIRAffine(nb::module_ &m);
void populateIRAttributes(nb::module_ &m);
void populateIRInterfaces(nb::module_ &m);
void populateIRTypes(nb::module_ &m);
-void registerMLIRErrorInIRCore();
} // namespace mlir::python
// -----------------------------------------------------------------------------
@@ -2415,18 +2399,6 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
- registerMLIRErrorInIRCore();
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
+ registerMLIRError();
+ registerMLIRErrorInCore();
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index b1fd48cf410af..3cfdfe49b4e3e 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,9 +8,9 @@
#include "Pass.h"
+#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir-c/Pass.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
@@ -254,4 +254,5 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
},
"Print the textual representation for this PassManager, suitable to "
"be passed to `parse` for round-tripping.");
+ registerMLIRError();
}
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 84f97abbae569..4700120422ddc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,10 +8,10 @@
#include "Rewrite.h"
-#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 3a4af1f066298..9286cead1a5c7 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -1002,11 +1002,16 @@ add_mlir_library(MLIRPythonSupport
LINK_COMPONENTS
Support
LINK_LIBS
+ Python::Module
${NB_LIBRARY_TARGET_NAME}
- MLIRCAPIIR
+ MLIRPythonCAPI
)
-target_link_libraries(MLIRPythonSupport PUBLIC ${NB_LIBRARY_TARGET_NAME})
nanobind_link_options(MLIRPythonSupport)
+get_target_property(_current_link_options MLIRPythonSupport LINK_OPTIONS)
+if(_current_link_options)
+ string(REPLACE "LINKER:-z,defs" "" _modified_link_options "${_current_link_options}")
+ set_property(TARGET MLIRPythonSupport PROPERTY LINK_OPTIONS "${_modified_link_options}")
+endif()
set_target_properties(MLIRPythonSupport PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
@@ -1042,4 +1047,7 @@ endif()
target_link_libraries(
MLIRPythonModules.extension._mlir.dso
PUBLIC MLIRPythonSupport)
+target_link_libraries(
+ MLIRPythonModules.extension._mlirPythonTestNanobind.dso
+ PUBLIC MLIRPythonSupport)
target_compile_definitions(MLIRPythonSupport PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a497fcccf13d7..e53f1ab3b4d3f 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -14,6 +14,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -26,6 +27,24 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
+struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirPythonTestTestTypeGetTypeID;
+ static constexpr const char *pyClassName = "TestType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](mlir::python::DefaultingPyMlirContext context) {
+ return PyTestType(context->getRef(),
+ mlirPythonTestTestTypeGet(context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
@@ -78,17 +97,7 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
// clang-format on
nb::arg("cls"), nb::arg("context").none() = nb::none());
- mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
- mlirPythonTestTestTypeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPythonTestTestTypeGet(ctx));
- },
- // clang-format off
- nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("context").none() = nb::none());
+ PyTestType::bind(m);
auto typeCls =
mlir_type_subclass(m, "TestIntegerRankedTensorType",
>From 162afde1375f49f0eb1c786ed155c70a5f46492f Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 16:19:14 -0800
Subject: [PATCH 04/27] rebase
---
mlir/include/mlir/Bindings/Python/Globals.h | 44 -----------------
mlir/include/mlir/Bindings/Python/IRCore.h | 14 +++---
mlir/lib/Bindings/Python/Globals.cpp | 3 --
mlir/lib/Bindings/Python/IRAttributes.cpp | 8 ----
mlir/lib/Bindings/Python/IRCore.cpp | 1 +
mlir/lib/Bindings/Python/MainModule.cpp | 53 +++++++++++++++++++++
mlir/python/CMakeLists.txt | 9 ++--
7 files changed, 65 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 19ffe8164d727..da06bbfaed479 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -201,50 +201,6 @@ class PyGlobals {
TracebackLoc tracebackLoc;
TypeIDAllocator typeIDAllocator;
};
-
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nanobind::object &o, bool enable) {
- nanobind::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nanobind::object &) {
- nanobind::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- static void bind(nanobind::module_ &m) {
- // Debug flags.
- nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- nanobind::arg("types"),
- "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- nanobind::arg("types"),
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- static nanobind::ft_mutex mutex;
-};
-
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 66a6272eaaf68..e82ee8da20fe5 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1350,12 +1350,12 @@ void registerMLIRErrorInCore();
/// Helper for creating an @classmethod.
template <class Func, typename... Args>
-static nanobind::object classmethod(Func f, Args... args) {
+nanobind::object classmethod(Func f, Args... args) {
nanobind::object cf = nanobind::cpp_function(f, args...);
return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
}
-static nanobind::object
+inline nanobind::object
createCustomDialectWrapper(const std::string &dialectNamespace,
nanobind::object dialectDescriptor) {
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
@@ -1368,21 +1368,21 @@ createCustomDialectWrapper(const std::string &dialectNamespace,
return (*dialectClass)(std::move(dialectDescriptor));
}
-static MlirStringRef toMlirStringRef(const std::string &s) {
+inline MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
-static MlirStringRef toMlirStringRef(std::string_view s) {
+inline MlirStringRef toMlirStringRef(std::string_view s) {
return mlirStringRefCreate(s.data(), s.size());
}
-static MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
+inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
}
/// Create a block, using the current location context if no locations are
/// specified.
-static MlirBlock
+inline MlirBlock
createBlock(const nanobind::sequence &pyArgTypes,
const std::optional<nanobind::sequence> &pyArgLocs) {
SmallVector<MlirType> argTypes;
@@ -1871,7 +1871,7 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<nanobind::typed<nanobind::object, PyType>>
+std::vector<nanobind::typed<nanobind::object, PyType>>
getValueTypes(Container &container, PyMlirContextRef &context) {
std::vector<nanobind::typed<nanobind::object, PyType>> result;
result.reserve(container.size());
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index bc6b210426221..97a2df37a729b 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -267,7 +267,4 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
}
return isUserTracebackFilenameCache[file];
}
-
-nanobind::ft_mutex PyGlobalDebugFlag::mutex;
-
} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 4323374a5d5b7..e39eabdb136b8 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -228,14 +228,6 @@ struct nb_format_descriptor<double> {
static const char *format() { return "d"; }
};
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(const nb::bytes &s) {
- return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
-}
-
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index ea1e62b8165ad..fc8743599508d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -26,6 +26,7 @@
#include <iostream>
#include <optional>
+#include <typeinfo>
namespace nb = nanobind;
using namespace nb::literals;
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 643851fcaf046..56dd4e0892655 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -185,6 +185,51 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}
+
+/// Wrapper for the global LLVM debugging flag.
+struct PyGlobalDebugFlag {
+ static void set(nanobind::object &o, bool enable) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirEnableGlobalDebug(enable);
+ }
+
+ static bool get(const nanobind::object &) {
+ nanobind::ft_lock_guard lock(mutex);
+ return mlirIsGlobalDebugEnabled();
+ }
+
+ static void bind(nanobind::module_ &m) {
+ // Debug flags.
+ nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+ .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+ &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+ .def_static(
+ "set_types",
+ [](const std::string &type) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugType(type.c_str());
+ },
+ nanobind::arg("types"),
+ "Sets specific debug types to be produced by LLVM.")
+ .def_static(
+ "set_types",
+ [](const std::vector<std::string> &types) {
+ std::vector<const char *> pointers;
+ pointers.reserve(types.size());
+ for (const std::string &str : types)
+ pointers.push_back(str.c_str());
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+ },
+ nanobind::arg("types"),
+ "Sets multiple specific debug types to be produced by LLVM.");
+ }
+
+private:
+ static nanobind::ft_mutex mutex;
+};
+
+nanobind::ft_mutex PyGlobalDebugFlag::mutex;
} // namespace
//------------------------------------------------------------------------------
@@ -1241,6 +1286,14 @@ static void populateIRCore(nb::module_ &m) {
return PyOpSuccessors(self.getOperation().getRef());
},
"Returns the list of Operation successors.")
+ .def(
+ "replace_uses_of_with",
+ [](PyOperation &self, PyValue &of, PyValue &with) {
+ mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
+ },
+ "of"_a, "with_"_a,
+ "Replaces uses of the 'of' value with the 'with' value inside the "
+ "operation.")
.def("_set_invalid", &PyOperation::setInvalid,
"Invalidate the operation.");
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 9286cead1a5c7..a32c85cf10359 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -1006,12 +1006,11 @@ add_mlir_library(MLIRPythonSupport
${NB_LIBRARY_TARGET_NAME}
MLIRPythonCAPI
)
-nanobind_link_options(MLIRPythonSupport)
-get_target_property(_current_link_options MLIRPythonSupport LINK_OPTIONS)
-if(_current_link_options)
- string(REPLACE "LINKER:-z,defs" "" _modified_link_options "${_current_link_options}")
- set_property(TARGET MLIRPythonSupport PROPERTY LINK_OPTIONS "${_modified_link_options}")
+if((CMAKE_SYSTEM_NAME STREQUAL "Linux") AND (NOT LLVM_USE_SANITIZER))
+ target_link_options(MLIRPythonSupport PRIVATE "-Wl,-z,undefs")
+ target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
endif()
+nanobind_link_options(MLIRPythonSupport)
set_target_properties(MLIRPythonSupport PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
>From 270162e2244e56e55f8135a7248c2531b1d2b969 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 16 Dec 2025 15:13:36 -0800
Subject: [PATCH 05/27] fix after rebase
---
mlir/include/mlir/Bindings/Python/IRCore.h | 4 ++--
mlir/lib/Bindings/Python/MainModule.cpp | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index e82ee8da20fe5..649dfce22ad35 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1852,12 +1852,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
- [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOperation> {
+ [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOpView> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
- return self.getParentOperation().getObject();
+ return self.getParentOperation()->createOpView();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 56dd4e0892655..f72775cc0b83a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -1993,7 +1993,7 @@ static void populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
- [](PyValue &self) -> nb::object {
+ [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -2001,7 +2001,7 @@ static void populateIRCore(nb::module_ &m) {
"expected the owner of the value in Python to match "
"that in "
"the IR");
- return self.getParentOperation().getObject();
+ return self.getParentOperation()->createOpView();
}
if (mlirValueIsABlockArgument(v)) {
>From b04263a58e6ecd374022def8adbb61f82b906adb Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 10:03:47 -0800
Subject: [PATCH 06/27] try fix windows badcast
---
mlir/include/mlir/Bindings/Python/Globals.h | 5 +----
mlir/lib/Bindings/Python/Globals.cpp | 5 +++++
mlir/python/CMakeLists.txt | 18 +++++++++---------
mlir/test/python/dialects/python_test.py | 12 +++---------
4 files changed, 18 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index da06bbfaed479..4584828868451 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -38,10 +38,7 @@ class PyGlobals {
~PyGlobals();
/// Most code should get the globals via this static accessor.
- static PyGlobals &get() {
- assert(instance && "PyGlobals is null");
- return *instance;
- }
+ static PyGlobals &get();
/// Get and set the list of parent modules to search for dialect
/// implementation classes.
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 97a2df37a729b..ecac571a132f6 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -39,6 +39,11 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
+PyGlobals &PyGlobals::get() {
+ assert(instance && "PyGlobals is null");
+ return *instance;
+}
+
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
{
nb::ft_lock_guard lock(mutex);
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index a32c85cf10359..a8bbd15124df5 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -994,17 +994,16 @@ endif()
get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso PROPERTY LINK_LIBRARIES)
list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
-add_mlir_library(MLIRPythonSupport
+add_library(MLIRPythonSupport SHARED
${PYTHON_SOURCE_DIR}/IRCore.cpp
${PYTHON_SOURCE_DIR}/Globals.cpp
- EXCLUDE_FROM_LIBMLIR
- SHARED
- LINK_COMPONENTS
- Support
- LINK_LIBS
- Python::Module
- ${NB_LIBRARY_TARGET_NAME}
- MLIRPythonCAPI
+)
+target_link_libraries(MLIRPythonSupport PRIVATE
+ LLVMSupport
+ Python::Module
+ ${NB_LIBRARY_TARGET_NAME}
+ MLIRPythonCAPI
+
)
if((CMAKE_SYSTEM_NAME STREQUAL "Linux") AND (NOT LLVM_USE_SANITIZER))
target_link_options(MLIRPythonSupport PRIVATE "-Wl,-z,undefs")
@@ -1028,6 +1027,7 @@ set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
set(eh_rtti_enable)
if(MSVC)
set(eh_rtti_enable /EHsc /GR)
+ set_property(TARGET MLIRPythonSupport PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
elseif(LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
set(eh_rtti_enable -frtti -fexceptions)
endif()
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 7bba20931e675..e50c8722f8959 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -613,12 +613,6 @@ def testCustomType():
b = TestType(a)
# Instance custom types should have typeids
assert isinstance(b.typeid, TypeID)
- # Subclasses of ir.Type should not have a static_typeid
- # CHECK: 'TestType' object has no attribute 'static_typeid'
- try:
- b.static_typeid
- except AttributeError as e:
- print(e)
i8 = IntegerType.get_signless(8)
try:
@@ -633,9 +627,9 @@ def testCustomType():
try:
TestType(42)
except TypeError as e:
- assert "Expected an MLIR object (got 42)" in str(e)
- except ValueError as e:
- assert "Cannot cast type to TestType (from 42)" in str(e)
+ assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
+ assert "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None" in str(e)
+ assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int" in str(e)
else:
raise
>From 81bf65bf3c3640fe0a38afa8054d5fbf0d06fbd3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 14:04:46 -0800
Subject: [PATCH 07/27] port mlir_attribute_subclass
---
mlir/test/python/dialects/python_test.py | 6 ++--
.../python/lib/PythonTestModuleNanobind.cpp | 34 ++++++++++++-------
2 files changed, 24 insertions(+), 16 deletions(-)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index e50c8722f8959..0ba56b7922ff5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -586,9 +586,9 @@ def testCustomAttribute():
try:
TestAttr(42)
except TypeError as e:
- assert "Expected an MLIR object (got 42)" in str(e)
- except ValueError as e:
- assert "Cannot cast attribute to TestAttr (from 42)" in str(e)
+ assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
+ assert "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None" in str(e)
+ assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int" in str(e)
else:
raise
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index e53f1ab3b4d3f..c8b95e2316778 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -45,6 +45,26 @@ struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
}
};
+class PyTestAttr : public mlir::python::PyConcreteAttribute<PyTestAttr> {
+public:
+ static constexpr IsAFunctionTy isaFunction =
+ mlirAttributeIsAPythonTestTestAttribute;
+ static constexpr const char *pyClassName = "TestAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirPythonTestTestAttributeGetTypeID;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](mlir::python::DefaultingPyMlirContext context) {
+ return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
+ context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
@@ -84,19 +104,7 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
// clang-format on
- mlir_attribute_subclass(m, "TestAttr",
- mlirAttributeIsAPythonTestTestAttribute,
- mlirPythonTestTestAttributeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPythonTestTestAttributeGet(ctx));
- },
- // clang-format off
- nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("context").none() = nb::none());
-
+ PyTestAttr::bind(m);
PyTestType::bind(m);
auto typeCls =
>From b27d9e5fca8a82c49fd1a72f3c5b74f7e225f172 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 14:10:11 -0800
Subject: [PATCH 08/27] format
---
mlir/test/python/dialects/python_test.py | 30 +++++++++++++++++++-----
1 file changed, 24 insertions(+), 6 deletions(-)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 0ba56b7922ff5..9c0966d2d8798 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -586,9 +586,18 @@ def testCustomAttribute():
try:
TestAttr(42)
except TypeError as e:
- assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
- assert "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None" in str(e)
- assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int" in str(e)
+ assert (
+ "__init__(): incompatible function arguments. The following argument types are supported"
+ in str(e)
+ )
+ assert (
+ "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None"
+ in str(e)
+ )
+ assert (
+ "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int"
+ in str(e)
+ )
else:
raise
@@ -627,9 +636,18 @@ def testCustomType():
try:
TestType(42)
except TypeError as e:
- assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
- assert "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None" in str(e)
- assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int" in str(e)
+ assert (
+ "__init__(): incompatible function arguments. The following argument types are supported"
+ in str(e)
+ )
+ assert (
+ "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None"
+ in str(e)
+ )
+ assert (
+ "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int"
+ in str(e)
+ )
else:
raise
>From a1f3e174c926339120702d22c6b7f4607d47e679 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 18:40:39 -0800
Subject: [PATCH 09/27] massage cmake
---
mlir/cmake/modules/AddMLIRPython.cmake | 158 +++++++++++++++++++-----
mlir/examples/standalone/CMakeLists.txt | 3 +
mlir/python/CMakeLists.txt | 71 ++---------
3 files changed, 140 insertions(+), 92 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 47b7a58fc821d..1133eff4393d0 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -228,14 +228,19 @@ endfunction()
# aggregate dylib that is linked against.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
- ""
- "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
+ "SUPPORT_LIB"
+ "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;SOURCES_TYPE"
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
if(NOT ARG_ROOT_DIR)
set(ARG_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
endif()
+ if(ARG_SUPPORT_LIB)
+ set(SOURCES_TYPE "support")
+ else()
+ set(SOURCES_TYPE "extension")
+ endif()
set(_install_destination "src/python/${name}")
add_library(${name} INTERFACE)
@@ -243,7 +248,7 @@ function(declare_mlir_python_extension name)
# Yes: Leading-lowercase property names are load bearing and the recommended
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS"
- mlir_python_SOURCES_TYPE extension
+ mlir_python_SOURCES_TYPE "${SOURCES_TYPE}"
mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}"
mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
mlir_python_DEPENDS ""
@@ -297,6 +302,39 @@ function(_mlir_python_install_sources name source_root_dir destination)
)
endfunction()
+function(build_nanobind_lib)
+ cmake_parse_arguments(ARG
+ ""
+ "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY"
+ ""
+ ${ARGN})
+
+ if (NB_ABI MATCHES "[0-9]t")
+ set(_ft "-ft")
+ endif()
+ # nanobind does a string match on the suffix to figure out whether to build
+ # the lib with free threading...
+ set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+ set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
+ nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
+ endif()
+ set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
+ LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ # Needed for windows (and don't hurt others).
+ RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ )
+ mlir_python_setup_extension_rpath(${NB_LIBRARY_TARGET_NAME})
+ install(TARGETS ${NB_LIBRARY_TARGET_NAME}
+ COMPONENT ${ARG_INSTALL_COMPONENT}
+ LIBRARY DESTINATION "${ARG_INSTALL_DESTINATION}"
+ RUNTIME DESTINATION "${ARG_INSTALL_DESTINATION}"
+ )
+endfunction()
+
# Function: add_mlir_python_modules
# Adds python modules to a project, building them from a list of declared
# source groupings (see declare_mlir_python_sources and
@@ -318,8 +356,16 @@ function(add_mlir_python_modules name)
"ROOT_PREFIX;INSTALL_PREFIX"
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
+
+ # This call sets NB_LIBRARY_TARGET_NAME.
+ build_nanobind_lib(
+ INSTALL_COMPONENT ${name}
+ INSTALL_DESTINATION "${ARG_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ )
+
# Helper to process an individual target.
- function(_process_target modules_target sources_target)
+ function(_process_target modules_target sources_target support_libs)
get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
if(_source_type STREQUAL "pure")
@@ -337,16 +383,19 @@ function(add_mlir_python_modules name)
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
# Transform relative source to based on root dir.
set(_extension_target "${modules_target}.extension.${_module_name}.dso")
- add_mlir_python_extension(${_extension_target} "${_module_name}"
+ add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${modules_target}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
LINK_LIBS PRIVATE
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
+ ${support_libs}
)
add_dependencies(${modules_target} ${_extension_target})
mlir_python_setup_extension_rpath(${_extension_target})
+ elseif(_source_type STREQUAL "support")
+ # do nothing because already built
else()
message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}")
return()
@@ -356,8 +405,34 @@ function(add_mlir_python_modules name)
# Build the modules target.
add_custom_target(${name} ALL)
_flatten_mlir_python_targets(_flat_targets ${ARG_DECLARED_SOURCES})
+
+ # Build all support libs first.
+ set(_mlir_python_support_libs)
+ foreach(sources_target ${_flat_targets})
+ get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
+ if(_source_type STREQUAL "support")
+ get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
+ set(_extension_target "${name}.extension.${_module_name}.dso")
+ add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
+ INSTALL_COMPONENT ${name}
+ INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ SUPPORT_LIB
+ LINK_LIBS PRIVATE
+ LLVMSupport
+ Python::Module
+ ${sources_target}
+ ${ARG_COMMON_CAPI_LINK_LIBS}
+ )
+ add_dependencies(${name} ${_extension_target})
+ mlir_python_setup_extension_rpath(${_extension_target})
+ list(APPEND _mlir_python_support_libs "${_extension_target}")
+ endif()
+ endforeach()
+
+ # Build extensions.
foreach(sources_target ${_flat_targets})
- _process_target(${name} ${sources_target})
+ _process_target(${name} ${sources_target} ${_mlir_python_support_libs})
endforeach()
# Create an install target.
@@ -741,9 +816,9 @@ endfunction()
################################################################################
# Build python extension
################################################################################
-function(add_mlir_python_extension libname extname)
+function(add_mlir_python_extension libname extname nb_library_target_name)
cmake_parse_arguments(ARG
- ""
+ "SUPPORT_LIB"
"INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
"SOURCES;LINK_LIBS"
${ARGN})
@@ -760,41 +835,57 @@ function(add_mlir_python_extension libname extname)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- nanobind_add_module(${libname}
- NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- FREE_THREADED
- NB_SHARED
- ${ARG_SOURCES}
- )
+ if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
+ endif()
+
+ if(ARG_SUPPORT_LIB)
+ add_library(${libname} SHARED ${ARG_SOURCES})
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ target_link_options(${libname} PRIVATE "-Wl,-z,undefs")
+ endif()
+ nanobind_link_options(${libname})
+ target_compile_definitions(${libname} PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
+ if (MSVC)
+ set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
+ endif ()
+ else()
+ nanobind_add_module(${libname}
+ NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ FREE_THREADED
+ NB_SHARED
+ ${ARG_SOURCES}
+ )
+ endif()
+ target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES
AND (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL))
# Avoid some warnings from upstream nanobind.
# If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let
# the super project handle compile options as it wishes.
- get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES)
- target_compile_options(${NB_LIBRARY_TARGET_NAME}
+ target_compile_options(${nb_library_target_name}
PRIVATE
-Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- -Wno-missing-field-initializers
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ -Wno-missing-field-initializers
${eh_rtti_enable})
target_compile_options(${libname}
PRIVATE
-Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- -Wno-missing-field-initializers
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ -Wno-missing-field-initializers
${eh_rtti_enable})
endif()
@@ -813,11 +904,16 @@ function(add_mlir_python_extension libname extname)
target_compile_options(${libname} PRIVATE ${eh_rtti_enable})
# Configure the output to match python expectations.
+ if (ARG_SUPPORT_LIB)
+ set(_no_soname OFF)
+ else ()
+ set(_no_soname ON)
+ endif ()
set_target_properties(
${libname} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${ARG_OUTPUT_DIRECTORY}
OUTPUT_NAME "${extname}"
- NO_SONAME ON
+ NO_SONAME ${_no_soname}
)
if(WIN32)
diff --git a/mlir/examples/standalone/CMakeLists.txt b/mlir/examples/standalone/CMakeLists.txt
index c6c49fde12d2e..323716c4baf3a 100644
--- a/mlir/examples/standalone/CMakeLists.txt
+++ b/mlir/examples/standalone/CMakeLists.txt
@@ -66,6 +66,9 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
if(NOT MLIR_PYTHON_PACKAGE_PREFIX)
set(MLIR_PYTHON_PACKAGE_PREFIX "mlir_standalone" CACHE STRING "" FORCE)
endif()
+ if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir_standalone" CACHE STRING "" FORCE)
+ endif()
if(NOT MLIR_BINDINGS_PYTHON_INSTALL_PREFIX)
set(MLIR_BINDINGS_PYTHON_INSTALL_PREFIX "python_packages/standalone/${MLIR_PYTHON_PACKAGE_PREFIX}" CACHE STRING "" FORCE)
endif()
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index a8bbd15124df5..b22d2ec75b3ba 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -784,7 +784,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Nanobind
MODULE_NAME _mlirDialectsAMDGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
- PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectAMDGPU.cpp
PRIVATE_LINK_LIBS
@@ -841,6 +840,16 @@ if(MLIR_INCLUDE_TESTS)
)
endif()
+declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
+ SUPPORT_LIB
+ MODULE_NAME MLIRPythonSupport
+ ADD_TO_PARENT MLIRPythonSources.Core
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ SOURCES
+ IRCore.cpp
+ Globals.cpp
+)
+
################################################################################
# Common CAPI dependency DSO.
# All python extensions must link through one DSO which exports the CAPI, and
@@ -990,63 +999,3 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
endif()
endif()
-
-get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso PROPERTY LINK_LIBRARIES)
-list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
-add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
-add_library(MLIRPythonSupport SHARED
- ${PYTHON_SOURCE_DIR}/IRCore.cpp
- ${PYTHON_SOURCE_DIR}/Globals.cpp
-)
-target_link_libraries(MLIRPythonSupport PRIVATE
- LLVMSupport
- Python::Module
- ${NB_LIBRARY_TARGET_NAME}
- MLIRPythonCAPI
-
-)
-if((CMAKE_SYSTEM_NAME STREQUAL "Linux") AND (NOT LLVM_USE_SANITIZER))
- target_link_options(MLIRPythonSupport PRIVATE "-Wl,-z,undefs")
- target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
-endif()
-nanobind_link_options(MLIRPythonSupport)
-set_target_properties(MLIRPythonSupport PROPERTIES
- LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- # Needed for windows (and doesn't hurt others).
- RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
-)
-set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
- LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- # Needed for windows (and doesn't hurt others).
- RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
-)
-set(eh_rtti_enable)
-if(MSVC)
- set(eh_rtti_enable /EHsc /GR)
- set_property(TARGET MLIRPythonSupport PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
-elseif(LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
- set(eh_rtti_enable -frtti -fexceptions)
-endif()
-target_compile_options(MLIRPythonSupport PRIVATE ${eh_rtti_enable})
-if(APPLE)
- # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
- # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
- # for downstream users that do not do something like `-undefined dynamic_lookup`.
- # Same for the rest.
- target_link_options(MLIRPythonSupport PUBLIC
- "LINKER:-U,_PyClassMethod_New"
- "LINKER:-U,_PyCode_Addr2Location"
- "LINKER:-U,_PyFrame_GetLasti"
- )
-endif()
-target_link_libraries(
- MLIRPythonModules.extension._mlir.dso
- PUBLIC MLIRPythonSupport)
-target_link_libraries(
- MLIRPythonModules.extension._mlirPythonTestNanobind.dso
- PUBLIC MLIRPythonSupport)
-target_compile_definitions(MLIRPythonSupport PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
>From 0e350f34ccefd21e5ae515749242603f28a744a1 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sun, 21 Dec 2025 16:12:46 -0800
Subject: [PATCH 10/27] add standalone test/use of IRCore
---
.../include/Standalone-c/Dialects.h | 7 ++++++
.../examples/standalone/lib/CAPI/Dialects.cpp | 13 ++++++++++
.../python/StandaloneExtensionNanobind.cpp | 25 +++++++++++++++++++
.../standalone/test/python/smoketest.py | 4 +++
mlir/include/mlir/Bindings/Python/Globals.h | 1 -
5 files changed, 49 insertions(+), 1 deletion(-)
diff --git a/mlir/examples/standalone/include/Standalone-c/Dialects.h b/mlir/examples/standalone/include/Standalone-c/Dialects.h
index b3e47752ccc69..5aa9e004cb9fe 100644
--- a/mlir/examples/standalone/include/Standalone-c/Dialects.h
+++ b/mlir/examples/standalone/include/Standalone-c/Dialects.h
@@ -17,6 +17,13 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Standalone, standalone);
+MLIR_CAPI_EXPORTED MlirType mlirStandaloneCustomTypeGet(MlirContext ctx,
+ MlirStringRef value);
+
+MLIR_CAPI_EXPORTED bool mlirStandaloneTypeIsACustomType(MlirType t);
+
+MLIR_CAPI_EXPORTED MlirTypeID mlirStandaloneCustomTypeGetTypeID(void);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/examples/standalone/lib/CAPI/Dialects.cpp b/mlir/examples/standalone/lib/CAPI/Dialects.cpp
index 98006e81a3d26..4de55ba485490 100644
--- a/mlir/examples/standalone/lib/CAPI/Dialects.cpp
+++ b/mlir/examples/standalone/lib/CAPI/Dialects.cpp
@@ -9,7 +9,20 @@
#include "Standalone-c/Dialects.h"
#include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandaloneTypes.h"
#include "mlir/CAPI/Registration.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Standalone, standalone,
mlir::standalone::StandaloneDialect)
+
+MlirType mlirStandaloneCustomTypeGet(MlirContext ctx, MlirStringRef value) {
+ return wrap(mlir::standalone::CustomType::get(unwrap(ctx), unwrap(value)));
+}
+
+bool mlirStandaloneTypeIsACustomType(MlirType t) {
+ return llvm::isa<mlir::standalone::CustomType>(unwrap(t));
+}
+
+MlirTypeID mlirStandaloneCustomTypeGetTypeID() {
+ return wrap(mlir::standalone::CustomType::getTypeID());
+}
diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
index 0ec6cdfa7994b..37737cd89ee1e 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
@@ -11,17 +11,42 @@
#include "Standalone-c/Dialects.h"
#include "mlir-c/Dialect/Arith.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
+struct PyCustomType : mlir::python::PyConcreteType<PyCustomType> {
+ static constexpr IsAFunctionTy isaFunction = mlirStandaloneTypeIsACustomType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirStandaloneCustomTypeGetTypeID;
+ static constexpr const char *pyClassName = "CustomType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::string &value,
+ mlir::python::DefaultingPyMlirContext context) {
+ return PyCustomType(
+ context->getRef(),
+ mlirStandaloneCustomTypeGet(
+ context.get()->get(),
+ mlirStringRefCreateFromCString(value.c_str())));
+ },
+ nb::arg("value"), nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_standaloneDialectsNanobind, m) {
//===--------------------------------------------------------------------===//
// standalone dialect
//===--------------------------------------------------------------------===//
auto standaloneM = m.def_submodule("standalone");
+ PyCustomType::bind(standaloneM);
+
standaloneM.def(
"register_dialects",
[](MlirContext context, bool load) {
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index f8819841fac45..9c0ada92551af 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -14,3 +14,7 @@
# CHECK: %[[C:.*]] = arith.constant 2 : i32
# CHECK: standalone.foo %[[C]] : i32
print(str(module))
+
+ custom_type = standalone_d.CustomType.get("foo")
+ # CHECK: !standalone.custom<"foo">
+ print(custom_type)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 4584828868451..2184e7e2dc5ca 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -15,7 +15,6 @@
#include <unordered_set>
#include <vector>
-#include "mlir-c/Debug.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
>From 5e83caff3a457fe70e7cfbac3e89c96953bcaa1c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 08:24:41 -0800
Subject: [PATCH 11/27] disable LTO by default
---
mlir/cmake/modules/AddMLIRPython.cmake | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 1133eff4393d0..6fd19fe6736a8 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -317,6 +317,14 @@ function(build_nanobind_lib)
set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
+ # nanobind configures with LTO for shared build which doesn't work everywhere
+ # (see https://github.com/llvm/llvm-project/issues/139602).
+ if(NOT LLVM_ENABLE_LTO)
+ set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
+ INTERPROCEDURAL_OPTIMIZATION_RELEASE OFF
+ INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL OFF
+ )
+ endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
endif()
>From d744ace5b9a7ffb4555821ea971e8d0ce03144d3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 09:27:36 -0800
Subject: [PATCH 12/27] restore DISABLE_INSTALL
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 6fd19fe6736a8..92b10c935fbbe 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -684,6 +684,7 @@ function(add_mlir_python_common_capi_library name)
# Generate the aggregate .so that everything depends on.
add_mlir_aggregate(${name}
SHARED
+ DISABLE_INSTALL
EMBED_LIBS ${_embed_libs}
)
>From fb434340afe692c583bee8ea5420536ee0d0c288 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 11:09:32 -0800
Subject: [PATCH 13/27] set VISIBILITY_INLINES_HIDDEN for libMLIRPYthonSupport
---
mlir/cmake/modules/AddMLIRPython.cmake | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 92b10c935fbbe..0501e0b5e51fe 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -855,6 +855,11 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
endif()
nanobind_link_options(${libname})
target_compile_definitions(${libname} PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
+ set_target_properties(${libname} PROPERTIES
+ VISIBILITY_INLINES_HIDDEN OFF
+ C_VISIBILITY_PRESET default
+ CXX_VISIBILITY_PRESET default
+ )
if (MSVC)
set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif ()
>From 7d6349d559870d96cf0076d84c63fed17ca63e06 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 12:20:34 -0800
Subject: [PATCH 14/27] try MLIR_PYTHON_API_EXPORTED
---
mlir/cmake/modules/AddMLIRPython.cmake | 9 +-
mlir/include/mlir-c/Support.h | 2 +
mlir/include/mlir/Bindings/Python/Globals.h | 4 +-
mlir/include/mlir/Bindings/Python/IRCore.h | 121 +++++++++++---------
mlir/include/mlir/Bindings/Python/IRTypes.h | 3 +-
5 files changed, 75 insertions(+), 64 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 0501e0b5e51fe..8ec9304421b54 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -854,11 +854,10 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_link_options(${libname} PRIVATE "-Wl,-z,undefs")
endif()
nanobind_link_options(${libname})
- target_compile_definitions(${libname} PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
- set_target_properties(${libname} PROPERTIES
- VISIBILITY_INLINES_HIDDEN OFF
- C_VISIBILITY_PRESET default
- CXX_VISIBILITY_PRESET default
+ target_compile_definitions(${libname}
+ PRIVATE
+ NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ MLIR_CAPI_BUILDING_LIBRARY=1
)
if (MSVC)
set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 78fc94f93439e..6abd8894227c3 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -46,6 +46,8 @@
#define MLIR_CAPI_EXPORTED __attribute__((visibility("default")))
#endif
+#define MLIR_PYTHON_API_EXPORTED MLIR_CAPI_EXPORTED
+
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 2184e7e2dc5ca..112c7b9b0547f 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -31,7 +31,7 @@ namespace python {
/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
-class PyGlobals {
+class MLIR_PYTHON_API_EXPORTED PyGlobals {
public:
PyGlobals();
~PyGlobals();
@@ -117,7 +117,7 @@ class PyGlobals {
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
- class TracebackLoc {
+ class MLIR_PYTHON_API_EXPORTED TracebackLoc {
public:
bool locTracebacksEnabled();
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 649dfce22ad35..ceedeb691eb58 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -52,7 +52,7 @@ class PyValue;
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
template <typename T>
-class PyObjectRef {
+class MLIR_PYTHON_API_EXPORTED PyObjectRef {
public:
PyObjectRef(T *referrent, nanobind::object object)
: referrent(referrent), object(std::move(object)) {
@@ -111,7 +111,7 @@ class PyObjectRef {
/// Context. Pushing a Context will not modify the Location or InsertionPoint
/// unless if they are from a different context, in which case, they are
/// cleared.
-class PyThreadContextEntry {
+class MLIR_PYTHON_API_EXPORTED PyThreadContextEntry {
public:
enum class FrameKind {
Context,
@@ -167,7 +167,7 @@ class PyThreadContextEntry {
/// Wrapper around MlirLlvmThreadPool
/// Python object owns the C++ thread pool
-class PyThreadPool {
+class MLIR_PYTHON_API_EXPORTED PyThreadPool {
public:
PyThreadPool() {
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
@@ -190,7 +190,7 @@ class PyThreadPool {
/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
-class PyMlirContext {
+class MLIR_PYTHON_API_EXPORTED PyMlirContext {
public:
PyMlirContext() = delete;
PyMlirContext(MlirContext context);
@@ -271,7 +271,7 @@ class PyMlirContext {
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
-class DefaultingPyMlirContext
+class MLIR_PYTHON_API_EXPORTED DefaultingPyMlirContext
: public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
using Defaulting::Defaulting;
@@ -283,7 +283,7 @@ class DefaultingPyMlirContext
/// MlirContext. The lifetime of the context will extend at least to the
/// lifetime of these instances.
/// Immutable objects that depend on a context extend this directly.
-class BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED BaseContextObject {
public:
BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
assert(this->contextRef &&
@@ -298,7 +298,7 @@ class BaseContextObject {
};
/// Wrapper around an MlirLocation.
-class PyLocation : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject {
public:
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
: BaseContextObject(std::move(contextRef)), loc(loc) {}
@@ -329,7 +329,7 @@ class PyLocation : public BaseContextObject {
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
/// nested diagnostics (in the notes) as well.
-class PyDiagnostic {
+class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
public:
PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
void invalidate();
@@ -379,7 +379,7 @@ class PyDiagnostic {
/// The object may remain live from a Python perspective for an arbitrary time
/// after detachment, but there is nothing the user can do with it (since there
/// is no way to attach an existing handler object).
-class PyDiagnosticHandler {
+class MLIR_PYTHON_API_EXPORTED PyDiagnosticHandler {
public:
PyDiagnosticHandler(MlirContext context, nanobind::object callback);
~PyDiagnosticHandler();
@@ -407,7 +407,7 @@ class PyDiagnosticHandler {
/// RAII object that captures any error diagnostics emitted to the provided
/// context.
-struct PyMlirContext::ErrorCapture {
+struct MLIR_PYTHON_API_EXPORTED PyMlirContext::ErrorCapture {
ErrorCapture(PyMlirContextRef ctx)
: ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
ctx->get(), handler, /*userData=*/this,
@@ -434,7 +434,7 @@ struct PyMlirContext::ErrorCapture {
/// plugins which extend dialect functionality through extension python code.
/// This should be seen as the "low-level" object and `Dialect` as the
/// high-level, user facing object.
-class PyDialectDescriptor : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyDialectDescriptor : public BaseContextObject {
public:
PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
: BaseContextObject(std::move(contextRef)), dialect(dialect) {}
@@ -447,7 +447,7 @@ class PyDialectDescriptor : public BaseContextObject {
/// User-level object for accessing dialects with dotted syntax such as:
/// ctx.dialect.std
-class PyDialects : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyDialects : public BaseContextObject {
public:
PyDialects(PyMlirContextRef contextRef)
: BaseContextObject(std::move(contextRef)) {}
@@ -458,7 +458,7 @@ class PyDialects : public BaseContextObject {
/// User-level dialect object. For dialects that have a registered extension,
/// this will be the base class of the extension dialect type. For un-extended,
/// objects of this type will be returned directly.
-class PyDialect {
+class MLIR_PYTHON_API_EXPORTED PyDialect {
public:
PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {}
@@ -471,7 +471,7 @@ class PyDialect {
/// Wrapper around an MlirDialectRegistry.
/// Upon construction, the Python wrapper takes ownership of the
/// underlying MlirDialectRegistry.
-class PyDialectRegistry {
+class MLIR_PYTHON_API_EXPORTED PyDialectRegistry {
public:
PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {}
PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {}
@@ -497,7 +497,7 @@ class PyDialectRegistry {
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
-class DefaultingPyLocation
+class MLIR_PYTHON_API_EXPORTED DefaultingPyLocation
: public Defaulting<DefaultingPyLocation, PyLocation> {
public:
using Defaulting::Defaulting;
@@ -511,7 +511,7 @@ class DefaultingPyLocation
/// This is the top-level, user-owned object that contains regions/ops/blocks.
class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;
-class PyModule : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyModule : public BaseContextObject {
public:
/// Returns a PyModule reference for the given MlirModule. This always returns
/// a new object.
@@ -551,7 +551,7 @@ class PyAsmState;
/// Base class for PyOperation and PyOpView which exposes the primary, user
/// visible methods for manipulating it.
-class PyOperationBase {
+class MLIR_PYTHON_API_EXPORTED PyOperationBase {
public:
virtual ~PyOperationBase() = default;
/// Implements the bound 'print' method and helps with others.
@@ -604,7 +604,8 @@ class PyOperationBase {
class PyOperation;
class PyOpView;
using PyOperationRef = PyObjectRef<PyOperation>;
-class PyOperation : public PyOperationBase, public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyOperation : public PyOperationBase,
+ public BaseContextObject {
public:
~PyOperation() override;
PyOperation &getOperation() override { return *this; }
@@ -722,7 +723,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// custom ODS-style operation classes. Since this class is subclass on the
/// python side, it must present an __init__ method that operates in pure
/// python types.
-class PyOpView : public PyOperationBase {
+class MLIR_PYTHON_API_EXPORTED PyOpView : public PyOperationBase {
public:
PyOpView(const nanobind::object &operationObject);
PyOperation &getOperation() override { return operation; }
@@ -758,7 +759,7 @@ class PyOpView : public PyOperationBase {
/// Wrapper around an MlirRegion.
/// Regions are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached regions.
-class PyRegion {
+class MLIR_PYTHON_API_EXPORTED PyRegion {
public:
PyRegion(PyOperationRef parentOperation, MlirRegion region)
: parentOperation(std::move(parentOperation)), region(region) {
@@ -777,7 +778,7 @@ class PyRegion {
};
/// Wrapper around an MlirAsmState.
-class PyAsmState {
+class MLIR_PYTHON_API_EXPORTED PyAsmState {
public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
@@ -812,7 +813,7 @@ class PyAsmState {
/// Wrapper around an MlirBlock.
/// Blocks are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached blocks.
-class PyBlock {
+class MLIR_PYTHON_API_EXPORTED PyBlock {
public:
PyBlock(PyOperationRef parentOperation, MlirBlock block)
: parentOperation(std::move(parentOperation)), block(block) {
@@ -836,7 +837,7 @@ class PyBlock {
/// Calls to insert() will insert a new operation before the
/// reference operation. If the reference operation is null, then appends to
/// the end of the block.
-class PyInsertionPoint {
+class MLIR_PYTHON_API_EXPORTED PyInsertionPoint {
public:
/// Creates an insertion point positioned after the last operation in the
/// block, but still inside the block.
@@ -877,7 +878,7 @@ class PyInsertionPoint {
};
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyType : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyType : public BaseContextObject {
public:
PyType(PyMlirContextRef contextRef, MlirType type)
: BaseContextObject(std::move(contextRef)), type(type) {}
@@ -903,7 +904,7 @@ class PyType : public BaseContextObject {
/// A TypeID provides an efficient and unique identifier for a specific C++
/// type. This allows for a C++ type to be compared, hashed, and stored in an
/// opaque context. This class wraps around the generic MlirTypeID.
-class PyTypeID {
+class MLIR_PYTHON_API_EXPORTED PyTypeID {
public:
PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
// Note, this tests whether the underlying TypeIDs are the same,
@@ -929,7 +930,7 @@ class PyTypeID {
/// concrete type class extends PyType); however, intermediate python-visible
/// base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
+class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1007,7 +1008,7 @@ class PyConcreteType : public BaseTy {
/// Wrapper around the generic MlirAttribute.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyAttribute : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAttribute : public BaseContextObject {
public:
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseContextObject(std::move(contextRef)), attr(attr) {}
@@ -1033,7 +1034,7 @@ class PyAttribute : public BaseContextObject {
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
/// TODO: Refactor this and the C-API to be based on an Identifier owned
/// by the context so as to avoid ownership issues here.
-class PyNamedAttribute {
+class MLIR_PYTHON_API_EXPORTED PyNamedAttribute {
public:
/// Constructs a PyNamedAttr that retains an owned name. This should be
/// used in any code that originates an MlirNamedAttribute from a python
@@ -1059,7 +1060,7 @@ class PyNamedAttribute {
/// concrete attribute class extends PyAttribute); however, intermediate
/// python-visible base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
+class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1149,7 +1150,8 @@ class PyConcreteAttribute : public BaseTy {
static void bindDerived(ClassTy &m) {}
};
-class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+class MLIR_PYTHON_API_EXPORTED PyStringAttribute
+ : public PyConcreteAttribute<PyStringAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
static constexpr const char *pyClassName = "StringAttr";
@@ -1166,7 +1168,7 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
/// value. For block argument values, this is the operation that contains the
/// block to which the value is an argument (blocks cannot be detached in Python
/// bindings so such operation always exists).
-class PyValue {
+class MLIR_PYTHON_API_EXPORTED PyValue {
public:
// The virtual here is "load bearing" in that it enables RTTI
// for PyConcreteValue CRTP classes that support maybeDownCast.
@@ -1196,7 +1198,7 @@ class PyValue {
};
/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
-class PyAffineExpr : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAffineExpr : public BaseContextObject {
public:
PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
: BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
@@ -1223,7 +1225,7 @@ class PyAffineExpr : public BaseContextObject {
MlirAffineExpr affineExpr;
};
-class PyAffineMap : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAffineMap : public BaseContextObject {
public:
PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
: BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
@@ -1244,7 +1246,7 @@ class PyAffineMap : public BaseContextObject {
MlirAffineMap affineMap;
};
-class PyIntegerSet : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyIntegerSet : public BaseContextObject {
public:
PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
: BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
@@ -1265,7 +1267,7 @@ class PyIntegerSet : public BaseContextObject {
};
/// Bindings for MLIR symbol tables.
-class PySymbolTable {
+class MLIR_PYTHON_API_EXPORTED PySymbolTable {
public:
/// Constructs a symbol table for the given operation.
explicit PySymbolTable(PyOperationBase &operation);
@@ -1317,7 +1319,7 @@ class PySymbolTable {
/// Custom exception that allows access to error diagnostic information. This is
/// converted to the `ir.MLIRError` python exception when thrown.
-struct MLIRError {
+struct MLIR_PYTHON_API_EXPORTED MLIRError {
MLIRError(llvm::Twine message,
std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
: message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
@@ -1342,7 +1344,7 @@ inline void registerMLIRError() {
});
}
-void registerMLIRErrorInCore();
+MLIR_PYTHON_API_EXPORTED void registerMLIRErrorInCore();
//------------------------------------------------------------------------------
// Utilities.
@@ -1455,7 +1457,7 @@ inline nanobind::object PyBlock::getCapsule() {
// Collections.
//------------------------------------------------------------------------------
-class PyRegionIterator {
+class MLIR_PYTHON_API_EXPORTED PyRegionIterator {
public:
PyRegionIterator(PyOperationRef operation, int nextIndex)
: operation(std::move(operation)), nextIndex(nextIndex) {}
@@ -1486,7 +1488,8 @@ class PyRegionIterator {
/// Regions of an op are fixed length and indexed numerically so are represented
/// with a sequence-like container.
-class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
+class MLIR_PYTHON_API_EXPORTED PyRegionList
+ : public Sliceable<PyRegionList, PyRegion> {
public:
static constexpr const char *pyClassName = "RegionSequence";
@@ -1529,7 +1532,7 @@ class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
PyOperationRef operation;
};
-class PyBlockIterator {
+class MLIR_PYTHON_API_EXPORTED PyBlockIterator {
public:
PyBlockIterator(PyOperationRef operation, MlirBlock next)
: operation(std::move(operation)), next(next) {}
@@ -1563,7 +1566,7 @@ class PyBlockIterator {
/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
/// we present them as a more full-featured list-like container but optimize
/// it for forward iteration. Blocks are always owned by a region.
-class PyBlockList {
+class MLIR_PYTHON_API_EXPORTED PyBlockList {
public:
PyBlockList(PyOperationRef operation, MlirRegion region)
: operation(std::move(operation)), region(region) {}
@@ -1636,7 +1639,7 @@ class PyBlockList {
MlirRegion region;
};
-class PyOperationIterator {
+class MLIR_PYTHON_API_EXPORTED PyOperationIterator {
public:
PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
: parentOperation(std::move(parentOperation)), next(next) {}
@@ -1672,7 +1675,7 @@ class PyOperationIterator {
/// Python, we present them as a more full-featured list-like container but
/// optimize it for forward iteration. Iterable operations are always owned
/// by a block.
-class PyOperationList {
+class MLIR_PYTHON_API_EXPORTED PyOperationList {
public:
PyOperationList(PyOperationRef parentOperation, MlirBlock block)
: parentOperation(std::move(parentOperation)), block(block) {}
@@ -1729,7 +1732,7 @@ class PyOperationList {
MlirBlock block;
};
-class PyOpOperand {
+class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
@@ -1754,7 +1757,7 @@ class PyOpOperand {
MlirOpOperand opOperand;
};
-class PyOpOperandIterator {
+class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator {
public:
PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
@@ -1785,7 +1788,7 @@ class PyOpOperandIterator {
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
+class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1843,7 +1846,7 @@ class PyConcreteValue : public PyValue {
};
/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
+class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
public:
static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
static constexpr const char *pyClassName = "OpResult";
@@ -1887,7 +1890,8 @@ getValueTypes(Container &container, PyMlirContextRef &context) {
/// elements, random access is cheap. The (returned) result list is associated
/// with the operation whose results these are, and thus extends the lifetime of
/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+class MLIR_PYTHON_API_EXPORTED PyOpResultList
+ : public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
@@ -1940,7 +1944,8 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
};
/// Python wrapper for MlirBlockArgument.
-class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+class MLIR_PYTHON_API_EXPORTED PyBlockArgument
+ : public PyConcreteValue<PyBlockArgument> {
public:
static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
static constexpr const char *pyClassName = "BlockArgument";
@@ -1979,7 +1984,7 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
/// Python bindings) and extends its lifetime.
-class PyBlockArgumentList
+class MLIR_PYTHON_API_EXPORTED PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
@@ -2032,7 +2037,8 @@ class PyBlockArgumentList
/// elements, random access is cheap. The (returned) operand list is associated
/// with the operation whose operands these are, and thus extends the lifetime
/// of this operation.
-class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
+class MLIR_PYTHON_API_EXPORTED PyOpOperandList
+ : public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
using SliceableT = Sliceable<PyOpOperandList, PyValue>;
@@ -2090,7 +2096,8 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation whose successors these are, and thus extends
/// the lifetime of this operation.
-class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
+class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
+ : public Sliceable<PyOpSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "OpSuccessors";
@@ -2138,7 +2145,8 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation and block whose successors these are, and thus
/// extends the lifetime of this operation and block.
-class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+class MLIR_PYTHON_API_EXPORTED PyBlockSuccessors
+ : public Sliceable<PyBlockSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockSuccessors";
@@ -2180,7 +2188,8 @@ class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
/// WARNING: This Sliceable is more expensive than the others here because
/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
/// operands) anew for each indexed access.
-class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+class MLIR_PYTHON_API_EXPORTED PyBlockPredecessors
+ : public Sliceable<PyBlockPredecessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockPredecessors";
@@ -2218,7 +2227,7 @@ class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
-class PyOpAttributeMap {
+class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
public:
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
@@ -2354,7 +2363,7 @@ class PyOpAttributeMap {
PyOperationRef operation;
};
-MlirValue getUniqueResult(MlirOperation operation);
+MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index ba9642cf2c6a2..cd0cfbc7d61d8 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -14,7 +14,8 @@
namespace mlir {
/// Shaped Type Interface - ShapedType
-class PyShapedType : public python::PyConcreteType<PyShapedType> {
+class MLIR_PYTHON_API_EXPORTED PyShapedType
+ : public python::PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
static constexpr const char *pyClassName = "ShapedType";
>From 661b9e8c3d8744d9b6ff443538ebd4b66decfb47 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 13:37:11 -0800
Subject: [PATCH 15/27] globals doesn't work
---
mlir/cmake/modules/AddMLIRPython.cmake | 18 ++++++++++++------
mlir/examples/standalone/pyproject.toml | 3 +++
.../standalone/test/python/smoketest.py | 19 +++----------------
mlir/include/mlir/Bindings/Python/Globals.h | 2 --
mlir/lib/Bindings/Python/Globals.cpp | 18 ++++++++++++------
mlir/test/Examples/standalone/test.wheel.toy | 7 ++-----
6 files changed, 32 insertions(+), 35 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 8ec9304421b54..90b26aad03828 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -317,6 +317,10 @@ function(build_nanobind_lib)
set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
+ target_compile_definitions(${NB_LIBRARY_TARGET_NAME}
+ PRIVATE
+ NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ )
# nanobind configures with LTO for shared build which doesn't work everywhere
# (see https://github.com/llvm/llvm-project/issues/139602).
if(NOT LLVM_ENABLE_LTO)
@@ -365,6 +369,10 @@ function(add_mlir_python_modules name)
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
+ if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
+ endif()
+
# This call sets NB_LIBRARY_TARGET_NAME.
build_nanobind_lib(
INSTALL_COMPONENT ${name}
@@ -420,6 +428,8 @@ function(add_mlir_python_modules name)
get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
if(_source_type STREQUAL "support")
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
+ # Use a similar mechanism as nanobind to help the runtime loader pick the correct lib.
+ set(_module_name "${_module_name}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(_extension_target "${name}.extension.${_module_name}.dso")
add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${name}
@@ -844,10 +854,6 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
- set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
- endif()
-
if(ARG_SUPPORT_LIB)
add_library(${libname} SHARED ${ARG_SOURCES})
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
@@ -859,9 +865,9 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
- if (MSVC)
+ if(MSVC)
set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
- endif ()
+ endif()
else()
nanobind_add_module(${libname}
NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
diff --git a/mlir/examples/standalone/pyproject.toml b/mlir/examples/standalone/pyproject.toml
index c4194153743ef..a90fb417eb426 100644
--- a/mlir/examples/standalone/pyproject.toml
+++ b/mlir/examples/standalone/pyproject.toml
@@ -56,8 +56,11 @@ MLIR_DIR = { env = "MLIR_DIR", default = "" }
# Non-optional
CMAKE_BUILD_TYPE = { env = "CMAKE_BUILD_TYPE", default = "Release" }
MLIR_ENABLE_BINDINGS_PYTHON = "ON"
+
# Effectively non-optional (any downstream project should specify this).
+MLIR_BINDINGS_PYTHON_NB_DOMAIN = "mlir_standalone"
MLIR_PYTHON_PACKAGE_PREFIX = "mlir_standalone"
+
# This specifies the directory in the install directory (i.e., /tmp/pip-wheel/platlib) where _mlir_libs, dialects, etc.
# are installed. Thus, this will be the package location (and the name of the package) that pip assumes is
# the root package.
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 9c0ada92551af..ec75790fffeb4 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,20 +1,7 @@
# RUN: %python %s nanobind | FileCheck %s
-from mlir_standalone.ir import *
-from mlir_standalone.dialects import standalone_nanobind as standalone_d
-with Context():
- standalone_d.register_dialects()
- module = Module.parse(
- """
- %0 = arith.constant 2 : i32
- %1 = standalone.foo %0 : i32
- """
- )
- # CHECK: %[[C:.*]] = arith.constant 2 : i32
- # CHECK: standalone.foo %[[C]] : i32
- print(str(module))
- custom_type = standalone_d.CustomType.get("foo")
- # CHECK: !standalone.custom<"foo">
- print(custom_type)
+import mlir_standalone.ir
+
+import mlir.ir
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 112c7b9b0547f..d9334cb35cc27 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -174,8 +174,6 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); }
private:
- static PyGlobals *instance;
-
nanobind::ft_mutex mutex;
/// Module name prefixes to search under for dialect implementation modules.
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index ecac571a132f6..7e451c8009809 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -19,6 +19,8 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include <iostream>
+
namespace nb = nanobind;
using namespace mlir;
@@ -26,22 +28,26 @@ using namespace mlir;
// PyGlobals
// -----------------------------------------------------------------------------
+namespace {
+python::PyGlobals *pyGlobalsInstance = nullptr;
+}
+
namespace mlir::python {
-PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
- assert(!instance && "PyGlobals already constructed");
- instance = this;
+ std::cerr << MAKE_MLIR_PYTHON_QUALNAME("dialects") << "\n";
+ assert(!pyGlobalsInstance && "PyGlobals already constructed");
+ pyGlobalsInstance = this;
// The default search path include {mlir.}dialects, where {mlir.} is the
// package prefix configured at compile time.
dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
-PyGlobals::~PyGlobals() { instance = nullptr; }
+PyGlobals::~PyGlobals() { pyGlobalsInstance = nullptr; }
PyGlobals &PyGlobals::get() {
- assert(instance && "PyGlobals is null");
- return *instance;
+ assert(pyGlobalsInstance && "PyGlobals is null");
+ return *pyGlobalsInstance;
}
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index c8d188a3cacd0..55847f7430648 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -1,10 +1,6 @@
# There's no real issue with windows here, it's just that some CMake generated paths for targets end up being longer
# than 255 chars when combined with the fact that pip wants to install into a tmp directory buried under
# C/Users/ContainerAdministrator/AppData/Local/Temp.
-# UNSUPPORTED: target={{.*(windows).*}}
-# REQUIRES: expensive_checks
-# REQUIRES: non-shared-libs-build
-# REQUIRES: bindings-python
# RUN: export CMAKE_BUILD_TYPE=%cmake_build_type
# RUN: export CMAKE_CXX_COMPILER=%host_cxx
@@ -15,7 +11,8 @@
# RUN: export LLVM_USE_LINKER=%llvm_use_linker
# RUN: export MLIR_DIR="%mlir_cmake_dir"
-# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v | tee %t
+# RUN: %python -m pip install scikit-build-core
+# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v --no-build-isolation | tee %t
# RUN: rm -rf "%mlir_obj_root/standalone-python-bindings-install"
# RUN: %python -m pip install standalone_python_bindings -f "%mlir_obj_root/wheelhouse" --target "%mlir_obj_root/standalone-python-bindings-install" -v | tee -a %t
>From b6c1b1d8086fcdf104b3884463d7f34d04f6914c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 17:00:17 -0800
Subject: [PATCH 16/27] works
---
.../examples/standalone/python/CMakeLists.txt | 1 +
.../python/StandaloneExtensionNanobind.cpp | 6 +-
.../standalone/test/python/smoketest.py | 21 ++-
mlir/include/mlir/Bindings/Python/Globals.h | 5 +-
mlir/include/mlir/Bindings/Python/IRCore.h | 43 ++++--
mlir/include/mlir/Bindings/Python/IRTypes.h | 8 +-
mlir/lib/Bindings/Python/Globals.cpp | 26 ++--
mlir/lib/Bindings/Python/IRAffine.cpp | 34 +++--
mlir/lib/Bindings/Python/IRAttributes.cpp | 18 ++-
mlir/lib/Bindings/Python/IRCore.cpp | 28 ++--
mlir/lib/Bindings/Python/IRInterfaces.cpp | 4 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 37 +++--
mlir/lib/Bindings/Python/MainModule.cpp | 133 ++++++++++--------
mlir/lib/Bindings/Python/Pass.cpp | 39 +++--
mlir/lib/Bindings/Python/Pass.h | 3 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 43 +++---
mlir/lib/Bindings/Python/Rewrite.h | 4 +-
mlir/python/CMakeLists.txt | 1 +
mlir/test/Examples/standalone/test.wheel.toy | 14 +-
.../python/lib/PythonTestModuleNanobind.cpp | 13 +-
20 files changed, 303 insertions(+), 178 deletions(-)
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index edaedf18cc843..d3b3aeadb6396 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -3,6 +3,7 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `mlir_standalone`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
################################################################################
diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
index 37737cd89ee1e..c568369913595 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
@@ -17,7 +17,8 @@
namespace nb = nanobind;
-struct PyCustomType : mlir::python::PyConcreteType<PyCustomType> {
+struct PyCustomType
+ : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyCustomType> {
static constexpr IsAFunctionTy isaFunction = mlirStandaloneTypeIsACustomType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirStandaloneCustomTypeGetTypeID;
@@ -28,7 +29,8 @@ struct PyCustomType : mlir::python::PyConcreteType<PyCustomType> {
c.def_static(
"get",
[](const std::string &value,
- mlir::python::DefaultingPyMlirContext context) {
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
return PyCustomType(
context->getRef(),
mlirStandaloneCustomTypeGet(
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index ec75790fffeb4..9132bab75cfcc 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,7 +1,24 @@
# RUN: %python %s nanobind | FileCheck %s
+import sys
+from mlir_standalone.ir import *
+from mlir_standalone.dialects import standalone_nanobind as standalone_d
+with Context():
+ standalone_d.register_dialects()
+ module = Module.parse(
+ """
+ %0 = arith.constant 2 : i32
+ %1 = standalone.foo %0 : i32
+ """
+ )
+ # CHECK: %[[C:.*]] = arith.constant 2 : i32
+ # CHECK: standalone.foo %[[C]] : i32
+ print(str(module))
-import mlir_standalone.ir
+ custom_type = standalone_d.CustomType.get("foo")
+ # CHECK: !standalone.custom<"foo">
+ print(custom_type)
-import mlir.ir
+if sys.argv[1] == "test-upstream":
+ from mlir.ir import *
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index d9334cb35cc27..5548a716cbe21 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -28,7 +28,7 @@
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
class MLIR_PYTHON_API_EXPORTED PyGlobals {
@@ -174,6 +174,8 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); }
private:
+ static PyGlobals *instance;
+
nanobind::ft_mutex mutex;
/// Module name prefixes to search under for dialect implementation modules.
@@ -195,6 +197,7 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
TracebackLoc tracebackLoc;
TypeIDAllocator typeIDAllocator;
};
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index ceedeb691eb58..7ed0a9f63bfda 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -33,6 +33,7 @@
namespace mlir {
namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyBlock;
class PyDiagnostic;
@@ -325,6 +326,26 @@ class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject {
MlirLocation loc;
};
+enum PyMlirDiagnosticSeverity : std::underlying_type<
+ MlirDiagnosticSeverity>::type {
+ MlirDiagnosticError = MlirDiagnosticError,
+ MlirDiagnosticWarning = MlirDiagnosticWarning,
+ MlirDiagnosticNote = MlirDiagnosticNote,
+ MlirDiagnosticRemark = MlirDiagnosticRemark
+};
+
+enum PyMlirWalkResult : std::underlying_type<MlirWalkResult>::type {
+ MlirWalkResultAdvance = MlirWalkResultAdvance,
+ MlirWalkResultInterrupt = MlirWalkResultInterrupt,
+ MlirWalkResultSkip = MlirWalkResultSkip
+};
+
+/// Traversal order for operation walk.
+enum PyMlirWalkOrder : std::underlying_type<MlirWalkOrder>::type {
+ MlirWalkPreOrder = MlirWalkPreOrder,
+ MlirWalkPostOrder = MlirWalkPostOrder
+};
+
/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
@@ -334,7 +355,7 @@ class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
void invalidate();
bool isValid() { return valid; }
- MlirDiagnosticSeverity getSeverity();
+ PyMlirDiagnosticSeverity getSeverity();
PyLocation getLocation();
nanobind::str getMessage();
nanobind::tuple getNotes();
@@ -342,7 +363,7 @@ class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
/// Materialized diagnostic information. This is safe to access outside the
/// diagnostic callback.
struct DiagnosticInfo {
- MlirDiagnosticSeverity severity;
+ PyMlirDiagnosticSeverity severity;
PyLocation location;
std::string message;
std::vector<DiagnosticInfo> notes;
@@ -573,8 +594,8 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
std::optional<int64_t> bytecodeVersion);
// Implement the walk method.
- void walk(std::function<MlirWalkResult(MlirOperation)> callback,
- MlirWalkOrder walkOrder);
+ void walk(std::function<PyMlirWalkResult(MlirOperation)> callback,
+ PyMlirWalkOrder walkOrder);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
@@ -2364,6 +2385,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
};
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
@@ -2371,11 +2393,16 @@ namespace nanobind {
namespace detail {
template <>
-struct type_caster<mlir::python::DefaultingPyMlirContext>
- : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
+struct type_caster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
+ : MlirDefaultingCaster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext> {
+};
template <>
-struct type_caster<mlir::python::DefaultingPyLocation>
- : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
+struct type_caster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation>
+ : MlirDefaultingCaster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation> {};
} // namespace detail
} // namespace nanobind
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index cd0cfbc7d61d8..87e0e10764bd8 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -12,10 +12,11 @@
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {
-
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Shaped Type Interface - ShapedType
class MLIR_PYTHON_API_EXPORTED PyShapedType
- : public python::PyConcreteType<PyShapedType> {
+ : public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
static constexpr const char *pyClassName = "ShapedType";
@@ -26,7 +27,8 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
private:
void requireHasRank();
};
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 7e451c8009809..e2e8693ba45f3 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -19,8 +19,6 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include <iostream>
-
namespace nb = nanobind;
using namespace mlir;
@@ -28,26 +26,24 @@ using namespace mlir;
// PyGlobals
// -----------------------------------------------------------------------------
-namespace {
-python::PyGlobals *pyGlobalsInstance = nullptr;
-}
-
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
- std::cerr << MAKE_MLIR_PYTHON_QUALNAME("dialects") << "\n";
- assert(!pyGlobalsInstance && "PyGlobals already constructed");
- pyGlobalsInstance = this;
+ assert(!instance && "PyGlobals already constructed");
+ instance = this;
// The default search path include {mlir.}dialects, where {mlir.} is the
// package prefix configured at compile time.
dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
-PyGlobals::~PyGlobals() { pyGlobalsInstance = nullptr; }
+PyGlobals::~PyGlobals() { instance = nullptr; }
PyGlobals &PyGlobals::get() {
- assert(pyGlobalsInstance && "PyGlobals is null");
- return *pyGlobalsInstance;
+ assert(instance && "PyGlobals is null");
+ return *instance;
}
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
@@ -278,4 +274,6 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
}
return isUserTracebackFilenameCache[file];
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 624d8f0fa57ce..ce235470bbdc7 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -30,7 +30,7 @@
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
using llvm::StringRef;
@@ -80,7 +80,9 @@ static bool isPermutation(const std::vector<PermutationTy> &permutation) {
return true;
}
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
/// and should be castable from it. Intermediate hierarchy classes can be
@@ -358,7 +360,9 @@ class PyAffineCeilDivExpr
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
return mlirAffineExprEqual(affineExpr, other.affineExpr);
@@ -380,7 +384,9 @@ PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) {
//------------------------------------------------------------------------------
// PyAffineMap and utilities.
//------------------------------------------------------------------------------
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// A list of expressions contained in an affine map. Internally these are
/// stored as a consecutive array leading to inexpensive random access. Both
@@ -416,7 +422,9 @@ class PyAffineMapExprList
PyAffineMap affineMap;
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyAffineMap::operator==(const PyAffineMap &other) const {
return mlirAffineMapEqual(affineMap, other.affineMap);
@@ -438,7 +446,9 @@ PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) {
//------------------------------------------------------------------------------
// PyIntegerSet and utilities.
//------------------------------------------------------------------------------
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyIntegerSetConstraint {
public:
@@ -492,7 +502,9 @@ class PyIntegerSetConstraintList
PyIntegerSet set;
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
return mlirIntegerSetEqual(integerSet, other.integerSet);
@@ -511,7 +523,9 @@ PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
rawIntegerSet);
}
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
@@ -998,4 +1012,6 @@ void populateIRAffine(nb::module_ &m) {
PyIntegerSetConstraint::bind(m);
PyIntegerSetConstraintList::bind(m);
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index e39eabdb136b8..a4d308bf049d8 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -24,7 +24,7 @@
namespace nb = nanobind;
using namespace nanobind::literals;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
@@ -121,7 +121,9 @@ subsequent processing.
type or if the buffer does not meet expectations.
)";
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
struct nb_buffer_info {
void *ptr = nullptr;
@@ -1745,7 +1747,9 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
void PyStringAttribute::bindDerived(ClassTy &c) {
c.def_static(
@@ -1791,7 +1795,9 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
@@ -1846,4 +1852,6 @@ void populateIRAttributes(nb::module_ &m) {
PyStridedLayoutAttribute::bind(m);
registerMLIRError();
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index fc8743599508d..069e177708afc 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -31,13 +31,14 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
-using namespace mlir::python;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -169,7 +170,8 @@ MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
if (self->ctx->emitErrorDiagnostics)
return mlirLogicalResultFailure();
- if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
+ if (mlirDiagnosticGetSeverity(diag) !=
+ MlirDiagnosticSeverity::MlirDiagnosticError)
return mlirLogicalResultFailure();
self->errors.emplace_back(PyDiagnostic(diag).getInfo());
@@ -356,9 +358,10 @@ void PyDiagnostic::checkValid() {
}
}
-MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
+PyMlirDiagnosticSeverity PyDiagnostic::getSeverity() {
checkValid();
- return mlirDiagnosticGetSeverity(diagnostic);
+ return static_cast<PyMlirDiagnosticSeverity>(
+ mlirDiagnosticGetSeverity(diagnostic));
}
PyLocation PyDiagnostic::getLocation() {
@@ -672,12 +675,12 @@ void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
}
void PyOperationBase::walk(
- std::function<MlirWalkResult(MlirOperation)> callback,
- MlirWalkOrder walkOrder) {
+ std::function<PyMlirWalkResult(MlirOperation)> callback,
+ PyMlirWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
struct UserData {
- std::function<MlirWalkResult(MlirOperation)> callback;
+ std::function<PyMlirWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
nb::object exceptionType;
@@ -687,7 +690,7 @@ void PyOperationBase::walk(
void *userData) {
UserData *calleeUserData = static_cast<UserData *>(userData);
try {
- return (calleeUserData->callback)(op);
+ return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
} catch (nb::python_error &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = std::string(e.what());
@@ -695,7 +698,8 @@ void PyOperationBase::walk(
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
- mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+ mlirOperationWalk(operation, walkCallback, &userData,
+ static_cast<MlirWalkOrder>(walkOrder));
if (userData.gotException) {
std::string message("Exception raised in callback: ");
message.append(userData.exceptionWhat);
@@ -1685,4 +1689,6 @@ void registerMLIRErrorInCore() {
}
});
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 78d1f977b2ebc..09112d4989ae4 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -25,7 +25,7 @@ namespace nb = nanobind;
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
constexpr static const char *constructorDoc =
R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
@@ -469,6 +469,6 @@ void populateIRInterfaces(nb::module_ &m) {
PyShapedTypeComponents::bind(m);
PyInferShapedTypeOpInterface::bind(m);
}
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7d9a0f16c913a..62fb2ef207d58 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -20,12 +20,14 @@
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
using llvm::Twine;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Checks whether the given type is an integer or float type.
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
@@ -508,10 +510,12 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
// Shaped Type Interface - ShapedType
-void mlir::PyShapedType::bindDerived(ClassTy &c) {
+void PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
[](PyShapedType &self) -> nb::typed<nb::object, PyType> {
@@ -616,17 +620,18 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
"shaped types.");
}
-void mlir::PyShapedType::requireHasRank() {
+void PyShapedType::requireHasRank() {
if (!mlirShapedTypeHasRank(*this)) {
throw nb::value_error(
"calling this method requires that the type has a rank.");
}
}
-const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
- mlirTypeIsAShaped;
+const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Vector Type subclass - VectorType.
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
@@ -1098,10 +1103,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
}
};
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
/// Opaque Type subclass - OpaqueType.
class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
public:
@@ -1141,9 +1142,13 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
@@ -1177,4 +1182,6 @@ void populateIRTypes(nb::module_ &m) {
PyOpaqueType::bind(m);
registerMLIRError();
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index f72775cc0b83a..392144ec5f0b7 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -16,7 +16,7 @@
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
static const char kModuleParseDocstring[] =
R"(Parses a module's assembly format from a string.
@@ -35,6 +35,56 @@ in `exceptions`. `exceptions` can be either a single operation or a list of
operations.
)";
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+/// Wrapper for the global LLVM debugging flag.
+struct PyGlobalDebugFlag {
+ static void set(nanobind::object &o, bool enable) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirEnableGlobalDebug(enable);
+ }
+
+ static bool get(const nanobind::object &) {
+ nanobind::ft_lock_guard lock(mutex);
+ return mlirIsGlobalDebugEnabled();
+ }
+
+ static void bind(nanobind::module_ &m) {
+ // Debug flags.
+ nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+ .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+ &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+ .def_static(
+ "set_types",
+ [](const std::string &type) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugType(type.c_str());
+ },
+ nanobind::arg("types"),
+ "Sets specific debug types to be produced by LLVM.")
+ .def_static(
+ "set_types",
+ [](const std::vector<std::string> &types) {
+ std::vector<const char *> pointers;
+ pointers.reserve(types.size());
+ for (const std::string &str : types)
+ pointers.push_back(str.c_str());
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+ },
+ nanobind::arg("types"),
+ "Sets multiple specific debug types to be produced by LLVM.");
+ }
+
+private:
+ static nanobind::ft_mutex mutex;
+};
+nanobind::ft_mutex PyGlobalDebugFlag::mutex;
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
namespace {
// see
// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
@@ -185,51 +235,6 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}
-
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nanobind::object &o, bool enable) {
- nanobind::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nanobind::object &) {
- nanobind::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- static void bind(nanobind::module_ &m) {
- // Debug flags.
- nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- nanobind::arg("types"),
- "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- nanobind::arg("types"),
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- static nanobind::ft_mutex mutex;
-};
-
-nanobind::ft_mutex PyGlobalDebugFlag::mutex;
} // namespace
//------------------------------------------------------------------------------
@@ -242,20 +247,20 @@ static void populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
// Enums.
//----------------------------------------------------------------------------
- nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
- .value("ERROR", MlirDiagnosticError)
- .value("WARNING", MlirDiagnosticWarning)
- .value("NOTE", MlirDiagnosticNote)
- .value("REMARK", MlirDiagnosticRemark);
+ nb::enum_<PyMlirDiagnosticSeverity>(m, "DiagnosticSeverity")
+ .value("ERROR", PyMlirDiagnosticSeverity::MlirDiagnosticError)
+ .value("WARNING", PyMlirDiagnosticSeverity::MlirDiagnosticWarning)
+ .value("NOTE", PyMlirDiagnosticSeverity::MlirDiagnosticNote)
+ .value("REMARK", PyMlirDiagnosticSeverity::MlirDiagnosticRemark);
- nb::enum_<MlirWalkOrder>(m, "WalkOrder")
- .value("PRE_ORDER", MlirWalkPreOrder)
- .value("POST_ORDER", MlirWalkPostOrder);
+ nb::enum_<PyMlirWalkOrder>(m, "WalkOrder")
+ .value("PRE_ORDER", PyMlirWalkOrder::MlirWalkPreOrder)
+ .value("POST_ORDER", PyMlirWalkOrder::MlirWalkPostOrder);
- nb::enum_<MlirWalkResult>(m, "WalkResult")
- .value("ADVANCE", MlirWalkResultAdvance)
- .value("INTERRUPT", MlirWalkResultInterrupt)
- .value("SKIP", MlirWalkResultSkip);
+ nb::enum_<PyMlirWalkResult>(m, "WalkResult")
+ .value("ADVANCE", PyMlirWalkResult::MlirWalkResultAdvance)
+ .value("INTERRUPT", PyMlirWalkResult::MlirWalkResultInterrupt)
+ .value("SKIP", PyMlirWalkResult::MlirWalkResultSkip);
//----------------------------------------------------------------------------
// Mapping of Diagnostics.
@@ -1186,7 +1191,7 @@ static void populateIRCore(nb::module_ &m) {
Note:
After erasing, any Python references to the operation become invalid.)")
.def("walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
+ nb::arg("walk_order") = PyMlirWalkOrder::MlirWalkPostOrder,
// clang-format off
nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
// clang-format on
@@ -2305,12 +2310,16 @@ static void populateIRCore(nb::module_ &m) {
PyAttrBuilderMap::bind(m);
}
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAffine(nb::module_ &m);
void populateIRAttributes(nb::module_ &m);
void populateIRInterfaces(nb::module_ &m);
void populateIRTypes(nb::module_ &m);
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
// -----------------------------------------------------------------------------
// Module initialization.
@@ -2453,5 +2462,5 @@ NB_MODULE(_mlir, m) {
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
registerMLIRError();
- registerMLIRErrorInCore();
+ // registerMLIRErrorInCore();
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 3cfdfe49b4e3e..e35923553e0a1 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -19,9 +19,11 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Owning Wrapper around a PassManager.
class PyPassManager {
@@ -53,23 +55,29 @@ class PyPassManager {
MlirPassManager passManager;
};
-} // namespace
+// Scoped: enumerators reuse (and must not clash with) the C API's names.
+enum class PyMlirPassDisplayMode : std::underlying_type<MlirPassDisplayMode>::type {
+ MLIR_PASS_DISPLAY_MODE_LIST = MLIR_PASS_DISPLAY_MODE_LIST,
+ MLIR_PASS_DISPLAY_MODE_PIPELINE = MLIR_PASS_DISPLAY_MODE_PIPELINE
+};
+struct PyMlirExternalPass : MlirExternalPass {};
/// Create the `mlir.passmanager` here.
-void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+void populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of enumerated types
//----------------------------------------------------------------------------
- nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode")
+ nb::enum_<PyMlirPassDisplayMode>(m, "PassDisplayMode")
.value("LIST", PyMlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_LIST)
.value("PIPELINE", PyMlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE);
//----------------------------------------------------------------------------
// Mapping of MlirExternalPass
//----------------------------------------------------------------------------
- nb::class_<MlirExternalPass>(m, "ExternalPass")
- .def("signal_pass_failure",
- [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+ nb::class_<PyMlirExternalPass>(m, "ExternalPass")
+ .def("signal_pass_failure", [](PyMlirExternalPass pass) {
+ mlirExternalPassSignalFailure(pass);
+ });
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
@@ -148,11 +156,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"Enable pass timing.")
.def(
"enable_statistics",
- [](PyPassManager &passManager, MlirPassDisplayMode displayMode) {
- mlirPassManagerEnableStatistics(passManager.get(), displayMode);
+ [](PyPassManager &passManager, PyMlirPassDisplayMode displayMode) {
+ mlirPassManagerEnableStatistics(
+ passManager.get(),
+ static_cast<MlirPassDisplayMode>(displayMode));
},
- "displayMode"_a =
- MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE,
+ "displayMode"_a = PyMlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE,
"Enable pass statistics.")
.def_static(
"parse",
@@ -211,7 +220,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
};
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::handle(static_cast<PyObject *>(userData))(op, pass);
+ nb::handle(static_cast<PyObject *>(userData))(
+ op, PyMlirExternalPass{pass.ptr});
};
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
@@ -256,3 +266,6 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"be passed to `parse` for round-tripping.");
registerMLIRError();
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index 0221bd10e723e..1a311666ebecd 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -13,8 +13,9 @@
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populatePassManagerSubmodule(nanobind::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 4700120422ddc..56068cd785573 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -22,9 +22,11 @@
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyPatternRewriter {
public:
@@ -60,6 +62,8 @@ class PyPatternRewriter {
PyMlirContextRef ctx;
};
+struct PyMlirPDLResultList : MlirPDLResultList {};
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -118,7 +122,7 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(PyPatternRewriter(rewriter), results,
+ f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
@@ -133,7 +137,7 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(PyPatternRewriter(rewriter), results,
+ f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
@@ -223,10 +227,8 @@ class PyRewritePatternSet {
MlirContext ctx;
};
-} // namespace
-
/// Create the `mlir.rewrite` here.
-void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
@@ -293,10 +295,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
- nb::class_<MlirPDLResultList>(m, "PDLResultList")
+ nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
.def(
"append",
- [](MlirPDLResultList results, const PyValue &value) {
+ [](PyMlirPDLResultList results, const PyValue &value) {
mlirPDLResultListPushBackValue(results, value);
},
// clang-format off
@@ -305,7 +307,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyOperation &op) {
+ [](PyMlirPDLResultList results, const PyOperation &op) {
mlirPDLResultListPushBackOperation(results, op);
},
// clang-format off
@@ -314,7 +316,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyType &type) {
+ [](PyMlirPDLResultList results, const PyType &type) {
mlirPDLResultListPushBackType(results, type);
},
// clang-format off
@@ -323,7 +325,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyAttribute &attr) {
+ [](PyMlirPDLResultList results, const PyAttribute &attr) {
mlirPDLResultListPushBackAttribute(results, attr);
},
// clang-format off
@@ -333,9 +335,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
- [](PyPDLPatternModule &self, MlirModule module) {
- new (&self)
- PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
+ [](PyPDLPatternModule &self, PyModule &module) {
+ new (&self) PyPDLPatternModule(
+ mlirPDLPatternModuleFromModule(module.get()));
},
// clang-format off
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
@@ -394,9 +396,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"results.")
.def(
"apply_patterns_and_fold_greedily",
- [](PyModule &module, MlirFrozenRewritePatternSet set) {
+ [](PyModule &module, PyFrozenRewritePatternSet &set) {
auto status =
- mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
+ mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
@@ -425,9 +427,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"results.")
.def(
"apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
+ [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
- op.getOperation(), set, {});
+ op.getOperation(), set.get(), {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
@@ -450,3 +452,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"Applies the given patterns to the given op by a fast walk-based "
"driver.");
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index f8ffdc7bdc458..d287f19187708 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -13,9 +13,9 @@
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateRewriteSubmodule(nanobind::module_ &m);
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b22d2ec75b3ba..2d2ae26bf3b28 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,7 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index 55847f7430648..91fed38e28612 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -1,6 +1,10 @@
# There's no real issue with windows here, it's just that some CMake generated paths for targets end up being longer
# than 255 chars when combined with the fact that pip wants to install into a tmp directory buried under
# C/Users/ContainerAdministrator/AppData/Local/Temp.
+# UNSUPPORTED: target={{.*(windows).*}}
+# REQUIRES: expensive_checks
+# REQUIRES: non-shared-libs-build
+# REQUIRES: bindings-python
# RUN: export CMAKE_BUILD_TYPE=%cmake_build_type
# RUN: export CMAKE_CXX_COMPILER=%host_cxx
@@ -11,21 +15,21 @@
# RUN: export LLVM_USE_LINKER=%llvm_use_linker
# RUN: export MLIR_DIR="%mlir_cmake_dir"
-# RUN: %python -m pip install scikit-build-core
-# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v --no-build-isolation | tee %t
+# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v | tee %t
# RUN: rm -rf "%mlir_obj_root/standalone-python-bindings-install"
# RUN: %python -m pip install standalone_python_bindings -f "%mlir_obj_root/wheelhouse" --target "%mlir_obj_root/standalone-python-bindings-install" -v | tee -a %t
-# RUN: export PYTHONPATH="%mlir_obj_root/standalone-python-bindings-install"
-# RUN: %python "%mlir_src_root/examples/standalone/test/python/smoketest.py" nanobind | tee -a %t
+# RUN: export PYTHONPATH="%mlir_obj_root/standalone-python-bindings-install:%mlir_obj_root/python_packages/mlir_core"
+# RUN: %python "%mlir_src_root/examples/standalone/test/python/smoketest.py" test-upstream 2>&1 | tee -a %t
# RUN: FileCheck --input-file=%t %s
# CHECK: Successfully built standalone-python-bindings
+# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
+
# CHECK: module {
# CHECK: %[[C2:.*]] = arith.constant 2 : i32
# CHECK: %[[V0:.*]] = standalone.foo %[[C2]] : i32
# CHECK: }
-
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index c8b95e2316778..43573cbc305fa 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -27,7 +27,8 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
-struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
+struct PyTestType
+ : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPythonTestTestTypeGetTypeID;
@@ -37,7 +38,8 @@ struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::DefaultingPyMlirContext context) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
return PyTestType(context->getRef(),
mlirPythonTestTestTypeGet(context.get()->get()));
},
@@ -45,7 +47,9 @@ struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
}
};
-class PyTestAttr : public mlir::python::PyConcreteAttribute<PyTestAttr> {
+class PyTestAttr
+ : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
+ PyTestAttr> {
public:
static constexpr IsAFunctionTy isaFunction =
mlirAttributeIsAPythonTestTestAttribute;
@@ -57,7 +61,8 @@ class PyTestAttr : public mlir::python::PyConcreteAttribute<PyTestAttr> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::DefaultingPyMlirContext context) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
context.get()->get()));
},
>From c5f6787318bee2fd9058104d2518715fa05a28ca Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 24 Dec 2025 10:55:35 -0800
Subject: [PATCH 17/27] try moving MLIR_BINDINGS_PYTHON_DOMAIN compile definition into AddMLIRPython.cmake
---
mlir/cmake/modules/AddMLIRPython.cmake | 6 ++++++
mlir/examples/standalone/python/CMakeLists.txt | 1 -
mlir/examples/standalone/test/python/smoketest.py | 2 +-
mlir/python/CMakeLists.txt | 1 -
4 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 90b26aad03828..0c6426b903ec7 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -369,6 +369,7 @@ function(add_mlir_python_modules name)
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
+ # TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX?
if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
endif()
@@ -863,6 +864,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_compile_definitions(${libname}
PRIVATE
NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
if(MSVC)
@@ -875,6 +877,10 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
NB_SHARED
${ARG_SOURCES}
)
+ target_compile_definitions(${libname}
+ PRIVATE
+ MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ )
endif()
target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index d3b3aeadb6396..edaedf18cc843 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -3,7 +3,6 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `mlir_standalone`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
-add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
################################################################################
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 9132bab75cfcc..e59f4ed2276d6 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -20,5 +20,5 @@
# CHECK: !standalone.custom<"foo">
print(custom_type)
-if sys.argv[1] == "test-upstream":
+if len(sys.argv) > 1 and sys.argv[1] == "test-upstream":
from mlir.ir import *
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2d2ae26bf3b28..b22d2ec75b3ba 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,7 +3,6 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
-add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
>From c26ccbee02355fae5d2c67733f10fad380b1b9c9 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 24 Dec 2025 16:11:18 -0800
Subject: [PATCH 18/27] remove registerMLIRError helpers; register MLIRError translator in MainModule
---
mlir/include/mlir/Bindings/Python/IRCore.h | 19 -------------------
mlir/lib/Bindings/Python/IRAttributes.cpp | 1 -
mlir/lib/Bindings/Python/IRCore.cpp | 16 ----------------
mlir/lib/Bindings/Python/IRTypes.cpp | 1 -
mlir/lib/Bindings/Python/MainModule.cpp | 16 ++++++++++++++--
mlir/lib/Bindings/Python/Pass.cpp | 1 -
6 files changed, 14 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 7ed0a9f63bfda..596ff7828631b 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1348,25 +1348,6 @@ struct MLIR_PYTHON_API_EXPORTED MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
-inline void registerMLIRError() {
- nanobind::register_exception_translator(
- [](const std::exception_ptr &p, void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nanobind::object obj =
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
-}
-
-MLIR_PYTHON_API_EXPORTED void registerMLIRErrorInCore();
-
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index a4d308bf049d8..f0f0ae9ba741e 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1850,7 +1850,6 @@ void populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
- registerMLIRError();
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 069e177708afc..26e0128752838 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1673,22 +1673,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
throw std::runtime_error(message);
}
}
-
-void registerMLIRErrorInCore() {
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
-}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 62fb2ef207d58..7350046f428c7 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -1180,7 +1180,6 @@ void populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
- registerMLIRError();
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 392144ec5f0b7..071f106da04bb 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -2461,6 +2461,18 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
- registerMLIRError();
- // registerMLIRErrorInCore();
+ nanobind::register_exception_translator(
+ [](const std::exception_ptr &p, void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nanobind::object obj =
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index e35923553e0a1..7bfc729568c42 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -264,7 +264,6 @@ void populatePassManagerSubmodule(nb::module_ &m) {
},
"Print the textual representation for this PassManager, suitable to "
"be passed to `parse` for round-tripping.");
- registerMLIRError();
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
>From 95eedfec0d34262b7ef3e111be66e9433061e1d6 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Dec 2025 12:43:39 -0800
Subject: [PATCH 19/27] check standalone
---
mlir/examples/standalone/test/lit.cfg.py | 12 +++++-------
mlir/examples/standalone/test/python/smoketest.py | 8 ++++----
mlir/test/Examples/standalone/test.toy | 1 +
mlir/test/Examples/standalone/test.wheel.toy | 4 +++-
4 files changed, 13 insertions(+), 12 deletions(-)
diff --git a/mlir/examples/standalone/test/lit.cfg.py b/mlir/examples/standalone/test/lit.cfg.py
index e27dddd7fb0b9..89cdd6889a1f2 100644
--- a/mlir/examples/standalone/test/lit.cfg.py
+++ b/mlir/examples/standalone/test/lit.cfg.py
@@ -61,10 +61,8 @@
llvm_config.add_tool_substitutions(tools, tool_dirs)
-llvm_config.with_environment(
- "PYTHONPATH",
- [
- os.path.join(config.mlir_obj_dir, "python_packages", "standalone"),
- ],
- append_path=True,
-)
+python_path = [os.path.join(config.mlir_obj_dir, "python_packages", "standalone")]
+if "PYTHONPATH" in os.environ:
+ python_path += [os.environ["PYTHONPATH"]]
+
+llvm_config.with_environment("PYTHONPATH", python_path, append_path=True)
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index e59f4ed2276d6..ab396327e1c4d 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,9 +1,10 @@
-# RUN: %python %s nanobind | FileCheck %s
-import sys
+# RUN: %python %s 2>&1 | FileCheck %s
from mlir_standalone.ir import *
from mlir_standalone.dialects import standalone_nanobind as standalone_d
+# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
+
with Context():
standalone_d.register_dialects()
module = Module.parse(
@@ -20,5 +21,4 @@
# CHECK: !standalone.custom<"foo">
print(custom_type)
-if len(sys.argv) > 1 and sys.argv[1] == "test-upstream":
- from mlir.ir import *
+from mlir.ir import *
diff --git a/mlir/test/Examples/standalone/test.toy b/mlir/test/Examples/standalone/test.toy
index a88c115ebf197..5efd7ca1f30bf 100644
--- a/mlir/test/Examples/standalone/test.toy
+++ b/mlir/test/Examples/standalone/test.toy
@@ -6,6 +6,7 @@
# RUN: -DMLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone \
# RUN: -DPython3_EXECUTABLE=%python \
# RUN: -DPython_EXECUTABLE=%python
+# RUN: export PYTHONPATH="%mlir_obj_root/python_packages/mlir_core"
# RUN: "%cmake_exe" --build . --target check-standalone | tee %t
# RUN: FileCheck --input-file=%t %s
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index 91fed38e28612..f56dc1e6d3a63 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -21,7 +21,7 @@
# RUN: %python -m pip install standalone_python_bindings -f "%mlir_obj_root/wheelhouse" --target "%mlir_obj_root/standalone-python-bindings-install" -v | tee -a %t
# RUN: export PYTHONPATH="%mlir_obj_root/standalone-python-bindings-install:%mlir_obj_root/python_packages/mlir_core"
-# RUN: %python "%mlir_src_root/examples/standalone/test/python/smoketest.py" test-upstream 2>&1 | tee -a %t
+# RUN: %python "%mlir_src_root/examples/standalone/test/python/smoketest.py" 2>&1 | tee -a %t
# RUN: FileCheck --input-file=%t %s
@@ -33,3 +33,5 @@
# CHECK: %[[C2:.*]] = arith.constant 2 : i32
# CHECK: %[[V0:.*]] = standalone.foo %[[C2]] : i32
# CHECK: }
+
+# CHECK: !standalone.custom<"foo">
>From 89bcbbda50b51c304d728a8799b6abe1835eab98 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Dec 2025 13:54:14 -0800
Subject: [PATCH 20/27] comments
---
mlir/cmake/modules/AddMLIRPython.cmake | 19 +++++++++++++------
1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 0c6426b903ec7..7c8b4a5c6cbf0 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -371,6 +371,9 @@ function(add_mlir_python_modules name)
# TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX?
if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ message(WARNING "MLIR_BINDINGS_PYTHON_NB_DOMAIN CMake var is not set - setting to a default `mlir`.\
+ It is highly recommend to set this to something unique so that your project's Python bindings do not collide with\
+ others'. See https://github.com/llvm/llvm-project/pull/171775 for more information.")
set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
endif()
@@ -858,13 +861,17 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
if(ARG_SUPPORT_LIB)
add_library(${libname} SHARED ${ARG_SOURCES})
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
- target_link_options(${libname} PRIVATE "-Wl,-z,undefs")
+ # nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols
+ # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
+ # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
+ # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
+ # maintenance - and this shlib is the only one where we do this).
+ target_link_options(${libname} PRIVATE "LINKER:-z,undefs")
endif()
nanobind_link_options(${libname})
target_compile_definitions(${libname}
PRIVATE
NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
if(MSVC)
@@ -877,12 +884,12 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
NB_SHARED
${ARG_SOURCES}
)
- target_compile_definitions(${libname}
- PRIVATE
- MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- )
endif()
target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
+ target_compile_definitions(${libname}
+ PRIVATE
+ MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ )
if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES
AND (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL))
>From bc445751b28b6081ba07084cd294c53dbff0654d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Dec 2025 19:13:46 -0800
Subject: [PATCH 21/27] address comments
---
mlir/lib/Bindings/Python/Rewrite.cpp | 31 ----------------------------
1 file changed, 31 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 56068cd785573..273a4870fec2a 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -394,37 +394,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
// clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
- .def(
- "apply_patterns_and_fold_greedily",
- [](PyModule &module, PyFrozenRewritePatternSet &set) {
- auto status =
- mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {});
- if (mlirLogicalResultIsFailure(status))
- throw std::runtime_error(
- "pattern application failed to converge");
- },
- "module"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
- "Applys the given patterns to the given module greedily while "
- "folding "
- "results.")
- .def(
- "apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
- auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
- op.getOperation(), set.get(), {});
- if (mlirLogicalResultIsFailure(status))
- throw std::runtime_error(
- "pattern application failed to converge");
- },
- "op"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
- "Applys the given patterns to the given op greedily while folding "
- "results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
>From 1b4de1f6a3c9c1cf2035928576757ad6db151423 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Dec 2025 12:50:47 -0800
Subject: [PATCH 22/27] fix empty _mlir_python_support_libs
---
mlir/cmake/modules/AddMLIRPython.cmake | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 7c8b4a5c6cbf0..c2e41c22c30c5 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -454,7 +454,7 @@ function(add_mlir_python_modules name)
# Build extensions.
foreach(sources_target ${_flat_targets})
- _process_target(${name} ${sources_target} ${_mlir_python_support_libs})
+ _process_target(${name} ${sources_target} "${_mlir_python_support_libs}")
endforeach()
# Create an install target.
>From 2dc0c1cd0a3a716fc6b112f716f9e4d8b61f8f21 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Dec 2025 13:24:06 -0800
Subject: [PATCH 23/27] parameterize add_mlir_python_modules by MLIR_BINDINGS_PYTHON_NB_DOMAIN
---
mlir/cmake/modules/AddMLIRPython.cmake | 31 ++++++++++++++++----------
1 file changed, 19 insertions(+), 12 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index c2e41c22c30c5..dcca6383a2f77 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -305,7 +305,7 @@ endfunction()
function(build_nanobind_lib)
cmake_parse_arguments(ARG
""
- "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY"
+ "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
""
${ARGN})
@@ -314,12 +314,12 @@ function(build_nanobind_lib)
endif()
# nanobind does a string match on the suffix to figure out whether to build
# the lib with free threading...
- set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+ set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
target_compile_definitions(${NB_LIBRARY_TARGET_NAME}
PRIVATE
- NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
# nanobind configures with LTO for shared build which doesn't work everywhere
# (see https://github.com/llvm/llvm-project/issues/139602).
@@ -365,16 +365,20 @@ endfunction()
function(add_mlir_python_modules name)
cmake_parse_arguments(ARG
""
- "ROOT_PREFIX;INSTALL_PREFIX"
+ "ROOT_PREFIX;INSTALL_PREFIX;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
# TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX?
- if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ if((NOT ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN) AND MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
+ endif()
+ if((NOT ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN) OR ("${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}" STREQUAL ""))
message(WARNING "MLIR_BINDINGS_PYTHON_NB_DOMAIN CMake var is not set - setting to a default `mlir`.\
It is highly recommend to set this to something unique so that your project's Python bindings do not collide with\
- others'. See https://github.com/llvm/llvm-project/pull/171775 for more information.")
- set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
+ others'. You also pass explicitly to `add_mlir_python_modules`.\
+ See https://github.com/llvm/llvm-project/pull/171775 for more information.")
+ set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir")
endif()
# This call sets NB_LIBRARY_TARGET_NAME.
@@ -382,6 +386,7 @@ function(add_mlir_python_modules name)
INSTALL_COMPONENT ${name}
INSTALL_DESTINATION "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
# Helper to process an individual target.
@@ -407,6 +412,7 @@ function(add_mlir_python_modules name)
INSTALL_COMPONENT ${modules_target}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
LINK_LIBS PRIVATE
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
@@ -433,12 +439,13 @@ function(add_mlir_python_modules name)
if(_source_type STREQUAL "support")
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
# Use a similar mechanism as nanobind to help the runtime loader pick the correct lib.
- set(_module_name "${_module_name}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+ set(_module_name "${_module_name}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(_extension_target "${name}.extension.${_module_name}.dso")
add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${name}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
SUPPORT_LIB
LINK_LIBS PRIVATE
LLVMSupport
@@ -842,7 +849,7 @@ endfunction()
function(add_mlir_python_extension libname extname nb_library_target_name)
cmake_parse_arguments(ARG
"SUPPORT_LIB"
- "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
+ "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"SOURCES;LINK_LIBS"
${ARGN})
if(ARG_UNPARSED_ARGUMENTS)
@@ -871,7 +878,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
nanobind_link_options(${libname})
target_compile_definitions(${libname}
PRIVATE
- NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
if(MSVC)
@@ -879,7 +886,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
endif()
else()
nanobind_add_module(${libname}
- NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
FREE_THREADED
NB_SHARED
${ARG_SOURCES}
@@ -888,7 +895,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
target_compile_definitions(${libname}
PRIVATE
- MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ MLIR_BINDINGS_PYTHON_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES
>From 7b521dae4b3af5f4b881ffedacab8b7f30281707 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Dec 2025 13:47:05 -0800
Subject: [PATCH 24/27] address comments
---
mlir/cmake/modules/AddMLIRPython.cmake | 50 ++++++++++++++++++--------
mlir/docs/Bindings/Python.md | 7 ++++
mlir/python/CMakeLists.txt | 2 +-
3 files changed, 43 insertions(+), 16 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index dcca6383a2f77..61c52ded8c14e 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -228,15 +228,15 @@ endfunction()
# aggregate dylib that is linked against.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
- "SUPPORT_LIB"
- "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;SOURCES_TYPE"
+ "_PRIVATE_SUPPORT_LIB"
+ "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
if(NOT ARG_ROOT_DIR)
set(ARG_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
endif()
- if(ARG_SUPPORT_LIB)
+ if(ARG__PRIVATE_SUPPORT_LIB)
set(SOURCES_TYPE "support")
else()
set(SOURCES_TYPE "extension")
@@ -309,6 +309,8 @@ function(build_nanobind_lib)
""
${ARGN})
+ # Only build in free-threaded mode if the Python ABI supports it.
+ # See https://github.com/wjakob/nanobind/blob/4ba51fcf795971c5d603d875ae4184bc0c9bd8e6/cmake/nanobind-config.cmake#L363-L371.
if (NB_ABI MATCHES "[0-9]t")
set(_ft "-ft")
endif()
@@ -321,6 +323,14 @@ function(build_nanobind_lib)
PRIVATE
NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ # nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols
+ # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
+ # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
+ # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
+ # maintenance).
+ target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "LINKER:-z,undefs")
+ endif()
# nanobind configures with LTO for shared build which doesn't work everywhere
# (see https://github.com/llvm/llvm-project/issues/139602).
if(NOT LLVM_ENABLE_LTO)
@@ -329,13 +339,10 @@ function(build_nanobind_lib)
INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL OFF
)
endif()
- if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
- target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
- endif()
set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
- # Needed for windows (and don't hurt others).
+ # Needed for windows (and doesn't hurt others).
RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
)
@@ -358,6 +365,11 @@ endfunction()
# for non-relocatable modules or a deeper directory tree for relocatable.
# INSTALL_PREFIX: Prefix into the install tree for installing the package.
# Typically mirrors the path above but without an absolute path.
+# MLIR_BINDINGS_PYTHON_NB_DOMAIN: nanobind (and MLIR) domain within which
+# extensions will be compiled. This determines whether this package
+# will share nanobind types with other bindings packages. Most likely
+# you want this to be unique to your project (and a specific set of bindings,
+# if your project builds several bindings packages).
# DECLARED_SOURCES: List of declared source groups to include. The entire
# DAG of source modules is included.
# COMMON_CAPI_LINK_LIBS: List of dylibs (typically one) to make every
@@ -446,10 +458,9 @@ function(add_mlir_python_modules name)
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- SUPPORT_LIB
+ _PRIVATE_SUPPORT_LIB
LINK_LIBS PRIVATE
LLVMSupport
- Python::Module
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
)
@@ -726,7 +737,7 @@ function(add_mlir_python_common_capi_library name)
set_target_properties(${name} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
- # Needed for windows (and don't hurt others).
+ # Needed for windows (and doesn't hurt others).
RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
)
@@ -848,7 +859,7 @@ endfunction()
################################################################################
function(add_mlir_python_extension libname extname nb_library_target_name)
cmake_parse_arguments(ARG
- "SUPPORT_LIB"
+ "_PRIVATE_SUPPORT_LIB"
"INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"SOURCES;LINK_LIBS"
${ARGN})
@@ -865,14 +876,14 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- if(ARG_SUPPORT_LIB)
+ if(ARG__PRIVATE_SUPPORT_LIB)
add_library(${libname} SHARED ${ARG_SOURCES})
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols
# (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
# but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
# we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
- # maintenance - and this shlib is the only one where we do this).
+ # maintenance).
target_link_options(${libname} PRIVATE "LINKER:-z,undefs")
endif()
nanobind_link_options(${libname})
@@ -942,12 +953,21 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_compile_options(${libname} PRIVATE ${eh_rtti_enable})
- # Configure the output to match python expectations.
- if (ARG_SUPPORT_LIB)
+ # Quoting CMake:
+ #
+ # "If you use it on normal shared libraries which other targets link against, on some platforms a
+ # linker will insert a full path to the library (as specified at link time) into the dynamic section of the
+ # dependent binary. Therefore, once installed, dynamic loader may eventually fail to locate the library
+ # for the binary."
+ #
+ # So for support libs we do need an SO name but for extensions we do not (they're MODULEs anyway -
+ # i.e., can't be linked against, only loaded).
+ if (ARG__PRIVATE_SUPPORT_LIB)
set(_no_soname OFF)
else ()
set(_no_soname ON)
endif ()
+ # Configure the output to match python expectations.
set_target_properties(
${libname} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${ARG_OUTPUT_DIRECTORY}
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 877ae5170d68c..27661f2880ed2 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -25,6 +25,13 @@
multiple Python implementations, setting this explicitly to the preferred
`python3` executable is strongly recommended.
+* **`MLIR_BINDINGS_PYTHON_NB_DOMAIN`**: `STRING`
+
+ nanobind (and MLIR) domain within which extensions will be compiled.
+ This determines whether this package will share nanobind types with other bindings packages.
+ Most likely you want this to be unique to your project (and a specific set of bindings).
+ This can also be passed explicitly to `add_mlir_python_modules` if your project builds several bindings packages.
+
### Recommended development practices
It is recommended to use a Python virtual environment. Many ways exist for this,
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b22d2ec75b3ba..4a9fb127ee08c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -841,7 +841,7 @@ if(MLIR_INCLUDE_TESTS)
endif()
declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
- SUPPORT_LIB
+ _PRIVATE_SUPPORT_LIB
MODULE_NAME MLIRPythonSupport
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
>From 18fc1e599f13f405560d10ddbcb03f2a4673585b Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Sat, 27 Dec 2025 04:57:19 +0000
Subject: [PATCH 25/27] Reflect rename in bazel file
---
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 785f1e01f5416..35e573bee8a1a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1172,14 +1172,14 @@ PYBIND11_FEATURES = [
filegroup(
name = "MLIRBindingsPythonSourceFiles",
srcs = [
+ "lib/Bindings/Python/Globals.cpp",
"lib/Bindings/Python/IRAffine.cpp",
"lib/Bindings/Python/IRAttributes.cpp",
"lib/Bindings/Python/IRCore.cpp",
"lib/Bindings/Python/IRInterfaces.cpp",
- "lib/Bindings/Python/IRModule.cpp",
"lib/Bindings/Python/IRTypes.cpp",
"lib/Bindings/Python/Pass.cpp",
"lib/Bindings/Python/Rewrite.cpp",
],
)
>From 30c603436042e48e20f9d6de9097eb73ba1e21bd Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 27 Dec 2025 10:23:40 -0800
Subject: [PATCH 26/27] address jpienaar comments
---
mlir/cmake/modules/AddMLIRPython.cmake | 8 ++++----
mlir/docs/Bindings/Python.md | 4 ++--
mlir/include/mlir/Bindings/Python/IRCore.h | 13 ++++++++-----
mlir/lib/Bindings/Python/IRCore.cpp | 3 +--
4 files changed, 15 insertions(+), 13 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 61c52ded8c14e..57207ab0e2a18 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -367,9 +367,9 @@ endfunction()
# Typically mirrors the path above but without an absolute path.
# MLIR_BINDINGS_PYTHON_NB_DOMAIN: nanobind (and MLIR) domain within which
# extensions will be compiled. This determines whether this package
-# will share nanobind types with other bindings packages. Most likely
-# you want this to be unique to your project (and a specific set of bindings,
-# if your project builds several bindings packages).
+# will share nanobind types with other bindings packages. Expected to be unique
+# per project (and per specific set of bindings, for projects with multiple
+# bindings packages).
# DECLARED_SOURCES: List of declared source groups to include. The entire
# DAG of source modules is included.
# COMMON_CAPI_LINK_LIBS: List of dylibs (typically one) to make every
@@ -452,7 +452,7 @@ function(add_mlir_python_modules name)
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
# Use a similar mechanism as nanobind to help the runtime loader pick the correct lib.
set(_module_name "${_module_name}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
- set(_extension_target "${name}.extension.${_module_name}.dso")
+ set(_extension_target "${name}.extension.${_module_name}.so")
add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${name}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 27661f2880ed2..d146037951ca9 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -29,8 +29,8 @@
nanobind (and MLIR) domain within which extensions will be compiled.
This determines whether this package will share nanobind types with other bindings packages.
- Most likely you want this to be unique to your project (and a specific set of bindings).
- This can also be passed explicitly to `add_mlir_python_modules` if your project builds several bindings packages.
+ Expected to be unique per project (and per specific set of bindings, for projects with multiple bindings packages).
+ Can also be passed explicitly to `add_mlir_python_modules`.
### Recommended development practices
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 596ff7828631b..616a9636ec799 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -29,6 +29,7 @@
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ThreadPool.h"
namespace mlir {
@@ -1403,11 +1404,13 @@ createBlock(const nanobind::sequence &pyArgTypes,
argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
}
- if (argTypes.size() != argLocs.size())
- throw nanobind::value_error(("Expected " + Twine(argTypes.size()) +
- " locations, got: " + Twine(argLocs.size()))
- .str()
- .c_str());
+ if (argTypes.size() != argLocs.size()) {
+ throw nanobind::value_error(
+ llvm::formatv("Expected {0} locations, got: {1}", argTypes.size(),
+ argLocs.size())
+ .str()
+ .c_str());
+ }
return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
}
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 26e0128752838..0fe508de38e85 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -924,9 +924,8 @@ nb::object PyOperation::create(std::string_view name,
// Construct the operation.
PyMlirContext::ErrorCapture errors(location.getContext());
MlirOperation operation = mlirOperationCreate(&state);
- if (!operation.ptr) {
+ if (!operation.ptr)
throw MLIRError("Operation creation failed", errors.take());
- }
PyOperationRef created =
PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);
>From ea8c091de923a233972050bb83e10c1e6fd5d5f8 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 27 Dec 2025 11:05:25 -0800
Subject: [PATCH 27/27] move impls
---
mlir/include/mlir/Bindings/Python/IRCore.h | 443 +++-----------------
mlir/lib/Bindings/Python/IRCore.cpp | 454 +++++++++++++++++++++
mlir/lib/Bindings/Python/MainModule.cpp | 20 +
3 files changed, 538 insertions(+), 379 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 616a9636ec799..e52f85f0017c5 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -171,20 +171,14 @@ class MLIR_PYTHON_API_EXPORTED PyThreadContextEntry {
/// Python object owns the C++ thread pool
class MLIR_PYTHON_API_EXPORTED PyThreadPool {
public:
- PyThreadPool() {
- ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
- }
+ PyThreadPool();
PyThreadPool(const PyThreadPool &) = delete;
PyThreadPool(PyThreadPool &&) = delete;
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
- std::string _mlir_thread_pool_ptr() const {
- std::stringstream ss;
- ss << ownedThreadPool.get();
- return ss.str();
- }
+ std::string _mlir_thread_pool_ptr() const;
private:
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
@@ -209,9 +203,7 @@ class MLIR_PYTHON_API_EXPORTED PyMlirContext {
/// Gets a strong reference to this context, which will ensure it is kept
/// alive for the life of the reference.
- PyMlirContextRef getRef() {
- return PyMlirContextRef(this, nanobind::cast(this));
- }
+ PyMlirContextRef getRef();
/// Gets a capsule wrapping the void* within the MlirContext.
nanobind::object getCapsule();
@@ -652,32 +644,17 @@ class MLIR_PYTHON_API_EXPORTED PyOperation : public PyOperationBase,
/// Detaches the operation from its parent block and updates its state
/// accordingly.
- void detachFromParent() {
- mlirOperationRemoveFromParent(getOperation());
- setDetached();
- parentKeepAlive = nanobind::object();
- }
+ void detachFromParent();
/// Gets the backing operation.
operator MlirOperation() const { return get(); }
- MlirOperation get() const {
- checkValid();
- return operation;
- }
+ MlirOperation get() const;
- PyOperationRef getRef() {
- return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
- }
+ PyOperationRef getRef();
bool isAttached() { return attached; }
- void setAttached(const nanobind::object &parent = nanobind::object()) {
- assert(!attached && "operation already attached");
- attached = true;
- }
- void setDetached() {
- assert(attached && "operation already detached");
- attached = false;
- }
+ void setAttached(const nanobind::object &parent = nanobind::object());
+ void setDetached();
void checkValid() const;
/// Gets the owning block or raises an exception if the operation has no
@@ -802,24 +779,8 @@ class MLIR_PYTHON_API_EXPORTED PyRegion {
/// Wrapper around an MlirAsmState.
class MLIR_PYTHON_API_EXPORTED PyAsmState {
public:
- PyAsmState(MlirValue value, bool useLocalScope) {
- flags = mlirOpPrintingFlagsCreate();
- // The OpPrintingFlags are not exposed Python side, create locally and
- // associate lifetime with the state.
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- state = mlirAsmStateCreateForValue(value, flags);
- }
-
- PyAsmState(PyOperationBase &operation, bool useLocalScope) {
- flags = mlirOpPrintingFlagsCreate();
- // The OpPrintingFlags are not exposed Python side, create locally and
- // associate lifetime with the state.
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- state =
- mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
- }
+ PyAsmState(MlirValue value, bool useLocalScope);
+ PyAsmState(PyOperationBase &operation, bool useLocalScope);
~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
@@ -898,6 +859,7 @@ class MLIR_PYTHON_API_EXPORTED PyInsertionPoint {
std::optional<PyOperationRef> refOperation;
PyBlock block;
};
+
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
class MLIR_PYTHON_API_EXPORTED PyType : public BaseContextObject {
@@ -1353,26 +1315,6 @@ struct MLIR_PYTHON_API_EXPORTED MLIRError {
// Utilities.
//------------------------------------------------------------------------------
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-nanobind::object classmethod(Func f, Args... args) {
- nanobind::object cf = nanobind::cpp_function(f, args...);
- return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
-}
-
-inline nanobind::object
-createCustomDialectWrapper(const std::string &dialectNamespace,
- nanobind::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
- if (!dialectClass) {
- // Use the base class.
- return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
- }
-
- // Create the custom implementation.
- return (*dialectClass)(std::move(dialectDescriptor));
-}
-
inline MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
@@ -1387,49 +1329,15 @@ inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
/// Create a block, using the current location context if no locations are
/// specified.
-inline MlirBlock
-createBlock(const nanobind::sequence &pyArgTypes,
- const std::optional<nanobind::sequence> &pyArgLocs) {
- SmallVector<MlirType> argTypes;
- argTypes.reserve(nanobind::len(pyArgTypes));
- for (const auto &pyType : pyArgTypes)
- argTypes.push_back(nanobind::cast<PyType &>(pyType));
-
- SmallVector<MlirLocation> argLocs;
- if (pyArgLocs) {
- argLocs.reserve(nanobind::len(*pyArgLocs));
- for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(nanobind::cast<PyLocation &>(pyLoc));
- } else if (!argTypes.empty()) {
- argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
- }
-
- if (argTypes.size() != argLocs.size()) {
- throw nanobind::value_error(
- llvm::formatv("Expected {0} locations, got: {1}", argTypes.size(),
- argLocs.size())
- .str()
- .c_str());
- }
- return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
-}
+MlirBlock createBlock(const nanobind::sequence &pyArgTypes,
+ const std::optional<nanobind::sequence> &pyArgLocs);
struct PyAttrBuilderMap {
- static bool dunderContains(const std::string &attributeKind) {
- return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
- }
+ static bool dunderContains(const std::string &attributeKind);
static nanobind::callable
- dunderGetItemNamed(const std::string &attributeKind) {
- auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
- if (!builder)
- throw nanobind::key_error(attributeKind.c_str());
- return *builder;
- }
+ dunderGetItemNamed(const std::string &attributeKind);
static void dunderSetItemNamed(const std::string &attributeKind,
- nanobind::callable func, bool replace) {
- PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
- }
+ nanobind::callable func, bool replace);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyAttrBuilderMap>(m, "AttrBuilder")
@@ -1450,14 +1358,6 @@ struct PyAttrBuilderMap {
}
};
-//------------------------------------------------------------------------------
-// PyBlock
-//------------------------------------------------------------------------------
-
-inline nanobind::object PyBlock::getCapsule() {
- return nanobind::steal<nanobind::object>(mlirPythonBlockToCapsule(get()));
-}
-
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@@ -1469,14 +1369,7 @@ class MLIR_PYTHON_API_EXPORTED PyRegionIterator {
PyRegionIterator &dunderIter() { return *this; }
- PyRegion dunderNext() {
- operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw nanobind::stop_iteration();
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
- return PyRegion(operation, region);
- }
+ PyRegion dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyRegionIterator>(m, "RegionIterator")
@@ -1499,17 +1392,9 @@ class MLIR_PYTHON_API_EXPORTED PyRegionList
static constexpr const char *pyClassName = "RegionSequence";
PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumRegions(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
+ intptr_t length = -1, intptr_t step = 1);
- PyRegionIterator dunderIter() {
- operation->checkValid();
- return PyRegionIterator(operation, startIndex);
- }
+ PyRegionIterator dunderIter();
static void bindDerived(ClassTy &c) {
c.def("__iter__", &PyRegionList::dunderIter,
@@ -1520,19 +1405,11 @@ class MLIR_PYTHON_API_EXPORTED PyRegionList
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyRegionList, PyRegion>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumRegions(operation->get());
- }
+ intptr_t getRawNumElements();
- PyRegion getRawElement(intptr_t pos) {
- operation->checkValid();
- return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
- }
+ PyRegion getRawElement(intptr_t pos);
- PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyRegionList(operation, startIndex, length, step);
- }
+ PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) const;
PyOperationRef operation;
};
@@ -1544,16 +1421,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockIterator {
PyBlockIterator &dunderIter() { return *this; }
- PyBlock dunderNext() {
- operation->checkValid();
- if (mlirBlockIsNull(next)) {
- throw nanobind::stop_iteration();
- }
-
- PyBlock returnBlock(operation, next);
- next = mlirBlockGetNextInRegion(next);
- return returnBlock;
- }
+ PyBlock dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyBlockIterator>(m, "BlockIterator")
@@ -1576,49 +1444,14 @@ class MLIR_PYTHON_API_EXPORTED PyBlockList {
PyBlockList(PyOperationRef operation, MlirRegion region)
: operation(std::move(operation)), region(region) {}
- PyBlockIterator dunderIter() {
- operation->checkValid();
- return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
- }
+ PyBlockIterator dunderIter();
- intptr_t dunderLen() {
- operation->checkValid();
- intptr_t count = 0;
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- count += 1;
- block = mlirBlockGetNextInRegion(block);
- }
- return count;
- }
+ intptr_t dunderLen();
- PyBlock dunderGetItem(intptr_t index) {
- operation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nanobind::index_error("attempt to access out of bounds block");
- }
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- if (index == 0) {
- return PyBlock(operation, block);
- }
- block = mlirBlockGetNextInRegion(block);
- index -= 1;
- }
- throw nanobind::index_error("attempt to access out of bounds block");
- }
+ PyBlock dunderGetItem(intptr_t index);
PyBlock appendBlock(const nanobind::args &pyArgTypes,
- const std::optional<nanobind::sequence> &pyArgLocs) {
- operation->checkValid();
- MlirBlock block =
- createBlock(nanobind::cast<nanobind::sequence>(pyArgTypes), pyArgLocs);
- mlirRegionAppendOwnedBlock(region, block);
- return PyBlock(operation, block);
- }
+ const std::optional<nanobind::sequence> &pyArgLocs);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyBlockList>(m, "BlockList")
@@ -1651,17 +1484,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationIterator {
PyOperationIterator &dunderIter() { return *this; }
- nanobind::typed<nanobind::object, PyOpView> dunderNext() {
- parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
- throw nanobind::stop_iteration();
- }
-
- PyOperationRef returnOperation =
- PyOperation::forOperation(parentOperation->getContext(), next);
- next = mlirOperationGetNextInBlock(next);
- return returnOperation->createOpView();
- }
+ nanobind::typed<nanobind::object, PyOpView> dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOperationIterator>(m, "OperationIterator")
@@ -1691,36 +1514,9 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
mlirBlockGetFirstOperation(block));
}
- intptr_t dunderLen() {
- parentOperation->checkValid();
- intptr_t count = 0;
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- count += 1;
- childOp = mlirOperationGetNextInBlock(childOp);
- }
- return count;
- }
+ intptr_t dunderLen();
- nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index) {
- parentOperation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nanobind::index_error("attempt to access out of bounds operation");
- }
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- if (index == 0) {
- return PyOperation::forOperation(parentOperation->getContext(), childOp)
- ->createOpView();
- }
- childOp = mlirOperationGetNextInBlock(childOp);
- index -= 1;
- }
- throw nanobind::index_error("attempt to access out of bounds operation");
- }
+ nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOperationList>(m, "OperationList")
@@ -1741,14 +1537,9 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
- nanobind::typed<nanobind::object, PyOpView> getOwner() {
- MlirOperation owner = mlirOpOperandGetOwner(opOperand);
- PyMlirContextRef context =
- PyMlirContext::forContext(mlirOperationGetContext(owner));
- return PyOperation::forOperation(context, owner)->createOpView();
- }
+ nanobind::typed<nanobind::object, PyOpView> getOwner() const;
- size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+ size_t getOperandNumber() const;
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOpOperand>(m, "OpOperand")
@@ -1768,14 +1559,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator {
PyOpOperandIterator &dunderIter() { return *this; }
- PyOpOperand dunderNext() {
- if (mlirOpOperandIsNull(opOperand))
- throw nanobind::stop_iteration();
-
- PyOpOperand returnOpOperand(opOperand);
- opOperand = mlirOpOperandGetNextUse(opOperand);
- return returnOpOperand;
- }
+ PyOpOperand dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOpOperandIterator>(m, "OpOperandIterator")
@@ -1931,19 +1715,12 @@ class MLIR_PYTHON_API_EXPORTED PyOpResultList
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpResultList, PyOpResult>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
+ intptr_t getRawNumElements();
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
+ PyOpResult getRawElement(intptr_t index);
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
+ PyOpResultList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
};
@@ -2017,22 +1794,14 @@ class MLIR_PYTHON_API_EXPORTED PyBlockArgumentList
friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
/// Returns the number of arguments in the list.
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirBlockGetNumArguments(block);
- }
+ intptr_t getRawNumElements();
/// Returns `pos`-the element in the list.
- PyBlockArgument getRawElement(intptr_t pos) {
- MlirValue argument = mlirBlockGetArgument(block, pos);
- return PyBlockArgument(operation, argument);
- }
+ PyBlockArgument getRawElement(intptr_t pos) const;
/// Returns a sublist of this list.
PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockArgumentList(operation, block, startIndex, length, step);
- }
+ intptr_t step) const;
PyOperationRef operation;
MlirBlock block;
@@ -2056,10 +1825,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandList
step),
operation(operation) {}
- void dunderSetItem(intptr_t index, PyValue value) {
- index = wrapIndex(index);
- mlirOperationSetOperand(operation->get(), index, value.get());
- }
+ void dunderSetItem(intptr_t index, PyValue value);
static void bindDerived(ClassTy &c) {
c.def("__setitem__", &PyOpOperandList::dunderSetItem,
@@ -2071,28 +1837,12 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandList
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpOperandList, PyValue>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumOperands(operation->get());
- }
+ intptr_t getRawNumElements();
- PyValue getRawElement(intptr_t pos) {
- MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
- return PyValue(pyOwner, operand);
- }
+ PyValue getRawElement(intptr_t pos);
- PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpOperandList(operation, startIndex, length, step);
- }
+ PyOpOperandList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
};
@@ -2114,10 +1864,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
step),
operation(operation) {}
- void dunderSetItem(intptr_t index, PyBlock block) {
- index = wrapIndex(index);
- mlirOperationSetSuccessor(operation->get(), index, block.get());
- }
+ void dunderSetItem(intptr_t index, PyBlock block);
static void bindDerived(ClassTy &c) {
c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
@@ -2129,19 +1876,12 @@ class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpSuccessors, PyBlock>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumSuccessors(operation->get());
- }
+ intptr_t getRawNumElements();
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
- return PyBlock(operation, block);
- }
+ PyBlock getRawElement(intptr_t pos);
- PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpSuccessors(operation, startIndex, length, step);
- }
+ PyOpSuccessors slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
};
@@ -2168,19 +1908,12 @@ class MLIR_PYTHON_API_EXPORTED PyBlockSuccessors
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockSuccessors, PyBlock>;
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumSuccessors(block.get());
- }
+ intptr_t getRawNumElements();
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
+ PyBlock getRawElement(intptr_t pos);
- PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyBlockSuccessors(block, operation, startIndex, length, step);
- }
+ PyBlockSuccessors slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
PyBlock block;
@@ -2211,20 +1944,12 @@ class MLIR_PYTHON_API_EXPORTED PyBlockPredecessors
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockPredecessors, PyBlock>;
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumPredecessors(block.get());
- }
+ intptr_t getRawNumElements();
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
+ PyBlock getRawElement(intptr_t pos);
PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockPredecessors(block, operation, startIndex, length, step);
- }
+ intptr_t step) const;
PyOperationRef operation;
PyBlock block;
@@ -2238,61 +1963,21 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
: operation(std::move(operation)) {}
nanobind::typed<nanobind::object, PyAttribute>
- dunderGetItemNamed(const std::string &name) {
- MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
- throw nanobind::key_error("attempt to access a non-existent attribute");
- }
- return PyAttribute(operation->getContext(), attr).maybeDownCast();
- }
+ dunderGetItemNamed(const std::string &name);
- PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0 || index >= dunderLen()) {
- throw nanobind::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr =
- mlirOperationGetAttribute(operation->get(), index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data,
- mlirIdentifierStr(namedAttr.name).length));
- }
+ PyNamedAttribute dunderGetItemIndexed(intptr_t index);
- void dunderSetItem(const std::string &name, const PyAttribute &attr) {
- mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
- attr);
- }
+ void dunderSetItem(const std::string &name, const PyAttribute &attr);
- void dunderDelItem(const std::string &name) {
- int removed = mlirOperationRemoveAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (!removed)
- throw nanobind::key_error("attempt to delete a non-existent attribute");
- }
+ void dunderDelItem(const std::string &name);
- intptr_t dunderLen() {
- return mlirOperationGetNumAttributes(operation->get());
- }
+ intptr_t dunderLen();
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
- operation->get(), toMlirStringRef(name)));
- }
+ bool dunderContains(const std::string &name);
static void
forEachAttr(MlirOperation op,
- llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
- intptr_t n = mlirOperationGetNumAttributes(op);
- for (intptr_t i = 0; i < n; ++i) {
- MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
- MlirStringRef name = mlirIdentifierStr(na.name);
- fn(name, na.attribute);
- }
- }
+ llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0fe508de38e85..2ea07a6c9adec 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -39,6 +39,20 @@ using llvm::Twine;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+//------------------------------------------------------------------------------
+// PyThreadPool
+//------------------------------------------------------------------------------
+
+PyThreadPool::PyThreadPool() {
+ ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+}
+
+std::string PyThreadPool::_mlir_thread_pool_ptr() const {
+ std::stringstream ss;
+ ss << ownedThreadPool.get();
+ return ss.str();
+}
+
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -62,6 +76,9 @@ PyMlirContext::~PyMlirContext() {
mlirContextDestroy(context);
}
+PyMlirContextRef PyMlirContext::getRef() {
+ return PyMlirContextRef(this, nanobind::cast(this));
+}
nb::object PyMlirContext::getCapsule() {
return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
}
@@ -598,6 +615,31 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
return PyOperation::createDetached(std::move(contextRef), op);
}
+void PyOperation::detachFromParent() {
+ mlirOperationRemoveFromParent(getOperation());
+ setDetached();
+ parentKeepAlive = nanobind::object();
+}
+
+MlirOperation PyOperation::get() const {
+ checkValid();
+ return operation;
+}
+
+PyOperationRef PyOperation::getRef() {
+ return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
+}
+
+void PyOperation::setAttached(const nanobind::object &parent) {
+ assert(!attached && "operation already attached");
+ attached = true;
+}
+
+void PyOperation::setDetached() {
+ assert(attached && "operation already detached");
+ attached = false;
+}
+
void PyOperation::checkValid() const {
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
@@ -1292,6 +1334,36 @@ PyOpView::PyOpView(const nb::object &operationObject)
: operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
operationObject(operation.getRef().getObject()) {}
+//------------------------------------------------------------------------------
+// PyBlock
+//------------------------------------------------------------------------------
+
+nanobind::object PyBlock::getCapsule() {
+ return nanobind::steal<nanobind::object>(mlirPythonBlockToCapsule(get()));
+}
+
+//------------------------------------------------------------------------------
+// PyAsmState
+//------------------------------------------------------------------------------
+
+PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
+ flags = mlirOpPrintingFlagsCreate();
+ // The OpPrintingFlags are not exposed Python side, create locally and
+ // associate lifetime with the state.
+ if (useLocalScope)
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ state = mlirAsmStateCreateForValue(value, flags);
+}
+
+PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
+ flags = mlirOpPrintingFlagsCreate();
+ // The OpPrintingFlags are not exposed Python side, create locally and
+ // associate lifetime with the state.
+ if (useLocalScope)
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
+}
+
//------------------------------------------------------------------------------
// PyInsertionPoint.
//------------------------------------------------------------------------------
@@ -1672,6 +1744,388 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
throw std::runtime_error(message);
}
}
+
+/// Creates a new MlirBlock (not inserted into any region by this function)
+/// whose argument types are taken from `pyArgTypes`. Argument locations come
+/// from `pyArgLocs` when provided; otherwise every argument uses the location
+/// returned by DefaultingPyLocation::resolve(). Raises ValueError when the
+/// number of locations differs from the number of types.
+MlirBlock createBlock(const nanobind::sequence &pyArgTypes,
+                      const std::optional<nanobind::sequence> &pyArgLocs) {
+  SmallVector<MlirType> types;
+  types.reserve(nanobind::len(pyArgTypes));
+  for (const auto &item : pyArgTypes)
+    types.push_back(nanobind::cast<PyType &>(item));
+
+  SmallVector<MlirLocation> locs;
+  if (!pyArgLocs) {
+    // Only consult the default location when there is an argument needing it.
+    if (!types.empty())
+      locs.assign(types.size(), DefaultingPyLocation::resolve());
+  } else {
+    locs.reserve(nanobind::len(*pyArgLocs));
+    for (const auto &item : *pyArgLocs)
+      locs.push_back(nanobind::cast<PyLocation &>(item));
+  }
+
+  if (types.size() == locs.size())
+    return mlirBlockCreate(types.size(), types.data(), locs.data());
+
+  std::string message = ("Expected " + Twine(types.size()) +
+                         " locations, got: " + Twine(locs.size()))
+                            .str();
+  throw nanobind::value_error(message.c_str());
+}
+
+//------------------------------------------------------------------------------
+// PyAttrBuilderMap
+//------------------------------------------------------------------------------
+
+/// Returns true if a builder is registered for `attributeKind`.
+bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
+  auto registered = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+  return registered.has_value();
+}
+
+/// Returns the builder registered for `attributeKind`; raises KeyError (with
+/// the kind as message) when none is registered.
+nanobind::callable
+PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
+  if (auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind))
+    return *builder;
+  throw nanobind::key_error(attributeKind.c_str());
+}
+
+/// Registers `func` as the builder for `attributeKind`; `replace` is forwarded
+/// unchanged to PyGlobals::registerAttributeBuilder.
+void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
+                                          nanobind::callable func,
+                                          bool replace) {
+  auto &globals = PyGlobals::get();
+  globals.registerAttributeBuilder(attributeKind, std::move(func), replace);
+}
+
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+/// Returns the next region of the owning operation, raising StopIteration
+/// once all regions have been produced.
+PyRegion PyRegionIterator::dunderNext() {
+  operation->checkValid();
+  MlirOperation op = operation->get();
+  if (nextIndex < mlirOperationGetNumRegions(op)) {
+    auto position = nextIndex++;
+    return PyRegion(operation, mlirOperationGetRegion(op, position));
+  }
+  throw nanobind::stop_iteration();
+}
+
+/// Constructs a (possibly sliced) view over the regions of `operation`.
+/// A `length` of -1 means "all regions of the operation".
+PyRegionList::PyRegionList(PyOperationRef operation, intptr_t startIndex,
+                           intptr_t length, intptr_t step)
+    : Sliceable(startIndex,
+                length != -1 ? length
+                             : mlirOperationGetNumRegions(operation->get()),
+                step),
+      operation(std::move(operation)) {}
+
+/// Returns an iterator starting at this list's start index.
+PyRegionIterator PyRegionList::dunderIter() {
+  operation->checkValid();
+  return PyRegionIterator(operation, startIndex);
+}
+
+/// Total number of regions on the underlying operation (ignores slicing).
+intptr_t PyRegionList::getRawNumElements() {
+  operation->checkValid();
+  MlirOperation op = operation->get();
+  return mlirOperationGetNumRegions(op);
+}
+
+/// Returns the region at absolute position `pos` of the operation.
+PyRegion PyRegionList::getRawElement(intptr_t pos) {
+  operation->checkValid();
+  MlirRegion region = mlirOperationGetRegion(operation->get(), pos);
+  return PyRegion(operation, region);
+}
+
+/// Creates a new view over the same operation with the given slice bounds.
+PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
+                                 intptr_t step) const {
+  return PyRegionList(operation, startIndex, length, step);
+}
+
+/// Returns the current block and advances to the next block in the region,
+/// raising StopIteration when the end of the region is reached.
+PyBlock PyBlockIterator::dunderNext() {
+  operation->checkValid();
+  if (mlirBlockIsNull(next))
+    throw nanobind::stop_iteration();
+
+  PyBlock current(operation, next);
+  next = mlirBlockGetNextInRegion(next);
+  return current;
+}
+
+/// Returns an iterator positioned at the region's first block.
+PyBlockIterator PyBlockList::dunderIter() {
+  operation->checkValid();
+  MlirBlock first = mlirRegionGetFirstBlock(region);
+  return PyBlockIterator(operation, first);
+}
+
+/// Counts the blocks in the region by walking the block list.
+intptr_t PyBlockList::dunderLen() {
+  operation->checkValid();
+  intptr_t numBlocks = 0;
+  for (MlirBlock b = mlirRegionGetFirstBlock(region); !mlirBlockIsNull(b);
+       b = mlirBlockGetNextInRegion(b))
+    ++numBlocks;
+  return numBlocks;
+}
+
+/// Returns the block at `index` (negative values count from the end),
+/// raising IndexError when out of range.
+PyBlock PyBlockList::dunderGetItem(intptr_t index) {
+  operation->checkValid();
+  if (index < 0)
+    index += dunderLen();
+  if (index >= 0) {
+    for (MlirBlock b = mlirRegionGetFirstBlock(region); !mlirBlockIsNull(b);
+         b = mlirBlockGetNextInRegion(b), --index) {
+      if (index == 0)
+        return PyBlock(operation, b);
+    }
+  }
+  throw nanobind::index_error("attempt to access out of bounds block");
+}
+
+/// Creates a block with the given argument types/locations (see createBlock)
+/// and appends it to the end of the region.
+PyBlock
+PyBlockList::appendBlock(const nanobind::args &pyArgTypes,
+                         const std::optional<nanobind::sequence> &pyArgLocs) {
+  operation->checkValid();
+  auto argTypes = nanobind::cast<nanobind::sequence>(pyArgTypes);
+  MlirBlock block = createBlock(argTypes, pyArgLocs);
+  mlirRegionAppendOwnedBlock(region, block);
+  return PyBlock(operation, block);
+}
+
+/// Returns the current operation as an OpView and advances to the next
+/// operation in the block, raising StopIteration at the end of the block.
+nanobind::typed<nanobind::object, PyOpView> PyOperationIterator::dunderNext() {
+  parentOperation->checkValid();
+  if (mlirOperationIsNull(next))
+    throw nanobind::stop_iteration();
+
+  MlirOperation current = next;
+  PyOperationRef currentRef =
+      PyOperation::forOperation(parentOperation->getContext(), current);
+  next = mlirOperationGetNextInBlock(current);
+  return currentRef->createOpView();
+}
+
+/// Counts the operations in the block by walking it from the first one.
+intptr_t PyOperationList::dunderLen() {
+  parentOperation->checkValid();
+  intptr_t numOps = 0;
+  for (MlirOperation op = mlirBlockGetFirstOperation(block);
+       !mlirOperationIsNull(op); op = mlirOperationGetNextInBlock(op))
+    ++numOps;
+  return numOps;
+}
+
+/// Returns the operation at `index` (negative values count from the end) as
+/// an OpView, raising IndexError when out of range.
+nanobind::typed<nanobind::object, PyOpView>
+PyOperationList::dunderGetItem(intptr_t index) {
+  parentOperation->checkValid();
+  if (index < 0)
+    index += dunderLen();
+  if (index >= 0) {
+    for (MlirOperation op = mlirBlockGetFirstOperation(block);
+         !mlirOperationIsNull(op);
+         op = mlirOperationGetNextInBlock(op), --index) {
+      if (index == 0)
+        return PyOperation::forOperation(parentOperation->getContext(), op)
+            ->createOpView();
+    }
+  }
+  throw nanobind::index_error("attempt to access out of bounds operation");
+}
+
+/// Returns the operation that owns this operand, as an OpView.
+nanobind::typed<nanobind::object, PyOpView> PyOpOperand::getOwner() const {
+  MlirOperation ownerOp = mlirOpOperandGetOwner(opOperand);
+  MlirContext ownerCtx = mlirOperationGetContext(ownerOp);
+  PyMlirContextRef context = PyMlirContext::forContext(ownerCtx);
+  return PyOperation::forOperation(context, ownerOp)->createOpView();
+}
+
+/// Returns this operand's position in its owner's operand list.
+size_t PyOpOperand::getOperandNumber() const {
+  return mlirOpOperandGetOperandNumber(opOperand);
+}
+
+/// Returns the current use and advances to the next use, raising
+/// StopIteration when the use list is exhausted.
+PyOpOperand PyOpOperandIterator::dunderNext() {
+  if (mlirOpOperandIsNull(opOperand))
+    throw nanobind::stop_iteration();
+
+  MlirOpOperand current = opOperand;
+  opOperand = mlirOpOperandGetNextUse(current);
+  return PyOpOperand(current);
+}
+
+//------------------------------------------------------------------------------
+// PyConcreteValue
+//------------------------------------------------------------------------------
+
+/// Total number of results of the underlying operation (ignores slicing).
+intptr_t PyOpResultList::getRawNumElements() {
+  operation->checkValid();
+  MlirOperation op = operation->get();
+  return mlirOperationGetNumResults(op);
+}
+
+/// Returns result `index` of the operation wrapped as a PyOpResult.
+PyOpResult PyOpResultList::getRawElement(intptr_t index) {
+  MlirValue rawResult = mlirOperationGetResult(operation->get(), index);
+  // Kept as a named PyValue: PyOpResult is constructed from a PyValue lvalue.
+  PyValue asValue(operation, rawResult);
+  return PyOpResult(asValue);
+}
+
+/// Creates a new view over the same operation with the given slice bounds.
+PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
+                                     intptr_t step) const {
+  return PyOpResultList(operation, startIndex, length, step);
+}
+
+/// Total number of arguments of the underlying block (ignores slicing).
+intptr_t PyBlockArgumentList::getRawNumElements() {
+  operation->checkValid();
+  intptr_t numArgs = mlirBlockGetNumArguments(block);
+  return numArgs;
+}
+
+/// Returns block argument `pos` wrapped as a PyBlockArgument.
+PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
+  return PyBlockArgument(operation, mlirBlockGetArgument(block, pos));
+}
+
+/// Creates a new view over the same block with the given slice bounds.
+PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
+                                               intptr_t length,
+                                               intptr_t step) const {
+  return PyBlockArgumentList(operation, block, startIndex, length, step);
+}
+
+/// Replaces the operand at `index` (normalized through wrapIndex) with
+/// `value`.
+void PyOpOperandList::dunderSetItem(intptr_t index, PyValue value) {
+  index = wrapIndex(index);
+  mlirOperationSetOperand(operation->get(), index, value.get());
+}
+
+/// Total number of operands of the underlying operation (ignores slicing).
+intptr_t PyOpOperandList::getRawNumElements() {
+  operation->checkValid();
+  return mlirOperationGetNumOperands(operation->get());
+}
+
+/// Returns operand `pos` wrapped in a PyValue whose parent is the operation
+/// defining the value: the producing op for an op result, or the parent op of
+/// the owning block for a block argument.
+PyValue PyOpOperandList::getRawElement(intptr_t pos) {
+  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
+  MlirOperation owner;
+  if (mlirValueIsAOpResult(operand)) {
+    owner = mlirOpResultGetOwner(operand);
+  } else if (mlirValueIsABlockArgument(operand)) {
+    owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
+  } else {
+    // `assert` is compiled out under NDEBUG; throw so `owner` is never used
+    // uninitialized in release builds.
+    assert(false && "Value must be an block arg or op result.");
+    throw std::runtime_error("Value must be an block arg or op result.");
+  }
+  PyOperationRef pyOwner =
+      PyOperation::forOperation(operation->getContext(), owner);
+  return PyValue(pyOwner, operand);
+}
+
+/// Creates a new view over the same operation with the given slice bounds.
+PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
+                                       intptr_t step) const {
+  return PyOpOperandList(operation, startIndex, length, step);
+}
+
+/// Replaces the successor at `index` (normalized through wrapIndex) with
+/// `block`.
+void PyOpSuccessors::dunderSetItem(intptr_t index, PyBlock block) {
+  intptr_t position = wrapIndex(index);
+  mlirOperationSetSuccessor(operation->get(), position, block.get());
+}
+
+/// Total number of successors of the underlying operation (ignores slicing).
+intptr_t PyOpSuccessors::getRawNumElements() {
+  operation->checkValid();
+  MlirOperation op = operation->get();
+  return mlirOperationGetNumSuccessors(op);
+}
+
+/// Returns successor block `pos` of the operation.
+PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
+  return PyBlock(operation, mlirOperationGetSuccessor(operation->get(), pos));
+}
+
+/// Creates a new view over the same operation with the given slice bounds.
+PyOpSuccessors PyOpSuccessors::slice(intptr_t startIndex, intptr_t length,
+                                     intptr_t step) const {
+  return PyOpSuccessors(operation, startIndex, length, step);
+}
+
+/// Number of successor blocks of the wrapped block (ignores slicing).
+intptr_t PyBlockSuccessors::getRawNumElements() {
+  block.checkValid();
+  return mlirBlockGetNumSuccessors(block.get());
+}
+
+/// Returns successor `pos` of the wrapped block.
+PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
+  MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+  return PyBlock(operation, successor);
+}
+
+/// Creates a new view over the same block with the given slice bounds.
+PyBlockSuccessors PyBlockSuccessors::slice(intptr_t startIndex, intptr_t length,
+                                           intptr_t step) const {
+  return PyBlockSuccessors(block, operation, startIndex, length, step);
+}
+
+/// Number of predecessor blocks of the wrapped block (ignores slicing).
+intptr_t PyBlockPredecessors::getRawNumElements() {
+  block.checkValid();
+  return mlirBlockGetNumPredecessors(block.get());
+}
+
+/// Returns predecessor `pos` of the wrapped block.
+PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
+  MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+  return PyBlock(operation, predecessor);
+}
+
+/// Creates a new view over the same block with the given slice bounds.
+PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
+                                               intptr_t length,
+                                               intptr_t step) const {
+  return PyBlockPredecessors(block, operation, startIndex, length, step);
+}
+
+/// Returns the attribute named `name` (downcast to its concrete Python
+/// type), raising KeyError if the operation has no such attribute.
+nanobind::typed<nanobind::object, PyAttribute>
+PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
+  MlirAttribute attr =
+      mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
+  if (!mlirAttributeIsNull(attr))
+    return PyAttribute(operation->getContext(), attr).maybeDownCast();
+  throw nanobind::key_error("attempt to access a non-existent attribute");
+}
+
+/// Returns the named attribute at position `index` (negative values count
+/// from the end), raising IndexError when out of range.
+PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) {
+  intptr_t numAttrs = dunderLen();
+  if (index < 0)
+    index += numAttrs;
+  if (index < 0 || index >= numAttrs)
+    throw nanobind::index_error("attempt to access out of bounds attribute");
+
+  MlirNamedAttribute namedAttr =
+      mlirOperationGetAttribute(operation->get(), index);
+  MlirStringRef attrName = mlirIdentifierStr(namedAttr.name);
+  return PyNamedAttribute(namedAttr.attribute,
+                          std::string(attrName.data, attrName.length));
+}
+
+/// Sets (or overwrites) the attribute `name` on the operation.
+void PyOpAttributeMap::dunderSetItem(const std::string &name,
+                                     const PyAttribute &attr) {
+  MlirStringRef key = toMlirStringRef(name);
+  mlirOperationSetAttributeByName(operation->get(), key, attr);
+}
+
+/// Removes the attribute `name`, raising KeyError if it was not present.
+void PyOpAttributeMap::dunderDelItem(const std::string &name) {
+  if (!mlirOperationRemoveAttributeByName(operation->get(),
+                                          toMlirStringRef(name)))
+    throw nanobind::key_error("attempt to delete a non-existent attribute");
+}
+
+/// Number of attributes attached to the operation.
+intptr_t PyOpAttributeMap::dunderLen() {
+  MlirOperation op = operation->get();
+  return mlirOperationGetNumAttributes(op);
+}
+
+/// Returns true if the operation has an attribute named `name`.
+bool PyOpAttributeMap::dunderContains(const std::string &name) {
+  MlirAttribute attr =
+      mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
+  return !mlirAttributeIsNull(attr);
+}
+
+/// Invokes `fn(name, attribute)` for every attribute of `op`, in index order.
+void PyOpAttributeMap::forEachAttr(
+    MlirOperation op,
+    llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+  for (intptr_t i = 0, e = mlirOperationGetNumAttributes(op); i < e; ++i) {
+    MlirNamedAttribute namedAttr = mlirOperationGetAttribute(op, i);
+    fn(mlirIdentifierStr(namedAttr.name), namedAttr.attribute);
+  }
+}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 071f106da04bb..79c8e36609d76 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -235,6 +235,26 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}
+
+/// Helper for creating an @classmethod: builds a nanobind cpp_function from
+/// `f` (forwarding `args` to nanobind::cpp_function) and wraps it with
+/// Python's classmethod.
+template <class Func, typename... Args>
+nanobind::object classmethod(Func f, Args... args) {
+  nanobind::object cf = nanobind::cpp_function(f, args...);
+  // PyClassMethod_New returns a new (owned) reference, or NULL with a Python
+  // exception set. Adopt it with `steal`: `borrow` would add a second
+  // reference that is never released.
+  PyObject *method = PyClassMethod_New(cf.ptr());
+  if (!method)
+    throw nanobind::python_error();
+  return nanobind::steal<nanobind::object>(method);
+}
+
+/// Wraps `dialectDescriptor` in the Python class registered for
+/// `dialectNamespace`, falling back to the base PyDialect when no custom
+/// class has been registered.
+nanobind::object
+createCustomDialectWrapper(const std::string &dialectNamespace,
+                           nanobind::object dialectDescriptor) {
+  // Create the custom implementation if one is registered.
+  if (auto customClass = PyGlobals::get().lookupDialectClass(dialectNamespace))
+    return (*customClass)(std::move(dialectDescriptor));
+
+  // Use the base class.
+  return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
+}
} // namespace
//------------------------------------------------------------------------------
More information about the llvm-commits
mailing list