[Mlir-commits] [mlir] [mlir][Python] create MLIRPythonSupport (PR #171775)
Maksim Levental
llvmlistbot at llvm.org
Thu Dec 11 15:53:46 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/171775
>From c43c604fb20004b35d6c38880666c8fdae620191 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 23:57:13 -0800
Subject: [PATCH 1/3] [mlir][Python] create MLIRPythonSupport
---
mlir/python/CMakeLists.txt | 65 ++++++++++++++++++++++++++++++--------
1 file changed, 52 insertions(+), 13 deletions(-)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..02124d12aea40 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,8 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")  # hoisted: MLIRPythonSupport's output directories use it
+set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")  # hoisted: MLIRPythonSupport's source list uses it
################################################################################
# Structural groupings.
@@ -524,7 +526,6 @@ declare_mlir_dialect_python_bindings(
# dependencies.
################################################################################
-set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
declare_mlir_python_extension(MLIRPythonExtension.Core
MODULE_NAME _mlir
ADD_TO_PARENT MLIRPythonSources.Core
@@ -532,20 +533,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
MainModule.cpp
- IRAffine.cpp
- IRAttributes.cpp
- IRCore.cpp
- IRInterfaces.cpp
- IRModule.cpp
- IRTypes.cpp
Pass.cpp
Rewrite.cpp
# Headers must be included explicitly so they are installed.
- Globals.h
- IRModule.h
Pass.h
- NanobindUtils.h
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
@@ -768,8 +760,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectSMT.cpp
- # Headers must be included explicitly so they are installed.
- NanobindUtils.h
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
@@ -876,7 +866,6 @@ endif()
# once ready.
################################################################################
-set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
add_mlir_python_common_capi_library(MLIRPythonCAPI
INSTALL_COMPONENT MLIRPythonModules
INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
@@ -1013,3 +1002,53 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
endif()
endif()
+
+get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso PROPERTY LINK_LIBRARIES)
+list(FILTER NB_LIBRARY_TARGET_NAME INCLUDE REGEX "^nanobind")  # select the nanobind runtime target by name rather than by its link position
+add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
+add_mlir_library(MLIRPythonSupport  # shared library holding the core IR binding sources; the _mlir extension links it below
+ ${PYTHON_SOURCE_DIR}/Globals.cpp  # NOTE(review): Globals.cpp is created by PATCH 2/3 (IRModule.cpp rename); this patch alone likely fails to configure
+ ${PYTHON_SOURCE_DIR}/IRAffine.cpp
+ ${PYTHON_SOURCE_DIR}/IRAttributes.cpp
+ ${PYTHON_SOURCE_DIR}/IRCore.cpp
+ ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
+ ${PYTHON_SOURCE_DIR}/IRTypes.cpp
+ EXCLUDE_FROM_LIBMLIR  # keep these nanobind-based sources out of the aggregate libMLIR
+ SHARED
+ LINK_COMPONENTS
+ Support
+ LINK_LIBS
+ ${NB_LIBRARY_TARGET_NAME}
+ MLIRCAPIIR
+)
+target_link_libraries(MLIRPythonSupport PUBLIC ${NB_LIBRARY_TARGET_NAME})  # PUBLIC: targets linking MLIRPythonSupport also get the nanobind runtime
+nanobind_link_options(MLIRPythonSupport)
+set_target_properties(MLIRPythonSupport PROPERTIES
+ LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+ # Needed for windows (and doesn't hurt others): a DLL is a RUNTIME artifact.
+ RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+ # The DLL's import library (.lib) is an ARCHIVE artifact.
+ ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+)
+set(eh_rtti_enable)
+if(MSVC)  # true for both cl and clang-cl, which share the cl-style flags
+ set(eh_rtti_enable /EHsc /GR)
+elseif(LLVM_COMPILER_IS_GCC_COMPATIBLE)
+ set(eh_rtti_enable -frtti -fexceptions)
+endif()
+target_compile_options(MLIRPythonSupport PRIVATE ${eh_rtti_enable})  # force exceptions + RTTI on for this target only
+if(APPLE)
+ # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
+ # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
+ # for downstream users that do not do something like `-undefined dynamic_lookup`.
+ # The same applies to the other symbols listed below.
+ target_link_options(MLIRPythonSupport PUBLIC  # PUBLIC so consumers' links inherit the -U flags
+ "LINKER:-U,_PyClassMethod_New"
+ "LINKER:-U,_PyCode_Addr2Location"
+ "LINKER:-U,_PyFrame_GetLasti"
+ )
+endif()
+target_link_libraries(  # the _mlir extension now gets the IR core sources from MLIRPythonSupport
+ MLIRPythonModules.extension._mlir.dso
+ PUBLIC MLIRPythonSupport)
+
>From d728dfc726cf107a709d25e889c3076f443f4790 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 13:19:23 -0800
Subject: [PATCH 2/3] kind of working
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 +
.../mlir}/Bindings/Python/Globals.h | 48 +-
.../mlir/Bindings/Python/IRCore.h} | 1025 ++++-
.../mlir}/Bindings/Python/NanobindUtils.h | 0
mlir/lib/Bindings/Python/DialectSMT.cpp | 2 +-
.../Python/{IRModule.cpp => Globals.cpp} | 14 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 10 +-
mlir/lib/Bindings/Python/IRAttributes.cpp | 21 +-
mlir/lib/Bindings/Python/IRCore.cpp | 3309 +----------------
mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 22 +-
mlir/lib/Bindings/Python/MainModule.cpp | 2277 +++++++++++-
mlir/lib/Bindings/Python/Pass.cpp | 4 +-
mlir/lib/Bindings/Python/Pass.h | 2 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
mlir/lib/Bindings/Python/Rewrite.h | 2 +-
mlir/python/CMakeLists.txt | 19 +-
17 files changed, 3428 insertions(+), 3332 deletions(-)
rename mlir/{lib => include/mlir}/Bindings/Python/Globals.h (82%)
rename mlir/{lib/Bindings/Python/IRModule.h => include/mlir/Bindings/Python/IRCore.h} (57%)
rename mlir/{lib => include/mlir}/Bindings/Python/NanobindUtils.h (100%)
rename mlir/lib/Bindings/Python/{IRModule.cpp => Globals.cpp} (97%)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 8196e2a2a3321..b725a00c183ca 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -780,6 +780,7 @@ function(add_mlir_python_extension libname extname)
nanobind_add_module(${libname}
NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
FREE_THREADED
+ NB_SHARED
${ARG_SOURCES}
)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
similarity index 82%
rename from mlir/lib/Bindings/Python/Globals.h
rename to mlir/include/mlir/Bindings/Python/Globals.h
index 1e81f53e465ac..fea7a201453ce 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -15,10 +15,12 @@
#include <unordered_set>
#include <vector>
-#include "NanobindUtils.h"
+#include "mlir-c/Debug.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir/CAPI/Support.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -200,6 +202,50 @@ class PyGlobals {
TypeIDAllocator typeIDAllocator;
};
+/// Wrapper for the global LLVM debugging flag.
+struct PyGlobalDebugFlag {
+ static void set(nanobind::object &o, bool enable) {  // `o` is unused: def_prop_rw_static passes the class object
+ nanobind::ft_lock_guard lock(mutex);
+ mlirEnableGlobalDebug(enable);
+ }
+
+ static bool get(const nanobind::object &) {
+ nanobind::ft_lock_guard lock(mutex);
+ return mlirIsGlobalDebugEnabled();
+ }
+
+ static void bind(nanobind::module_ &m) {
+ // Debug flags.
+ nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+ .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+ &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+ .def_static(
+ "set_types",
+ [](const std::string &type) {  // overload: a single debug type
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugType(type.c_str());
+ },
+ nanobind::arg("types"),
+ "Sets specific debug types to be produced by LLVM.")
+ .def_static(
+ "set_types",
+ [](const std::vector<std::string> &types) {  // overload: a list of debug types
+ std::vector<const char *> pointers;  // borrows c_str() from `types`, which outlives the call below
+ pointers.reserve(types.size());
+ for (const std::string &str : types)
+ pointers.push_back(str.c_str());
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+ },
+ nanobind::arg("types"),
+ "Sets multiple specific debug types to be produced by LLVM.");
+ }
+
+private:
+ static nanobind::ft_mutex mutex;  // NOTE(review): declaration only; one .cpp must define it — confirm
+};
+
+
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRCore.h
similarity index 57%
rename from mlir/lib/Bindings/Python/IRModule.h
rename to mlir/include/mlir/Bindings/Python/IRCore.h
index e706be3b4d32a..488196ea42e44 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1,4 +1,4 @@
-//===- IRModules.h - IR Submodules of pybind module -----------------------===//
+//===- IRCore.h - IR helpers of python bindings ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +7,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
-#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
-#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+#ifndef MLIR_BINDINGS_PYTHON_IRCORE_H
+#define MLIR_BINDINGS_PYTHON_IRCORE_H
#include <optional>
#include <sstream>
@@ -20,12 +20,14 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ThreadPool.h"
@@ -1323,12 +1325,1017 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
-void populateIRAffine(nanobind::module_ &m);
-void populateIRAttributes(nanobind::module_ &m);
-void populateIRCore(nanobind::module_ &m);
-void populateIRInterfaces(nanobind::module_ &m);
-void populateIRTypes(nanobind::module_ &m);
+//------------------------------------------------------------------------------
+// Utilities (inline: this header is included by several translation units).
+//------------------------------------------------------------------------------
+
+/// Helper for creating an @classmethod.
+template <class Func, typename... Args>
+inline nanobind::object classmethod(Func f, Args... args) {  // inline, not static: no per-TU copy in a header
+ nanobind::object cf = nanobind::cpp_function(f, args...);
+ return nanobind::steal<nanobind::object>(PyClassMethod_New(cf.ptr()));  // PyClassMethod_New returns a new reference; steal() adopts it
+}
+
+inline nanobind::object  // registered dialect subclass instance, else a base PyDialect
+createCustomDialectWrapper(const std::string &dialectNamespace,
+ nanobind::object dialectDescriptor) {
+ auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
+ if (!dialectClass) {
+ // Use the base class.
+ return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
+ }
+
+ // Create the custom implementation.
+ return (*dialectClass)(std::move(dialectDescriptor));
+}
+
+inline MlirStringRef toMlirStringRef(const std::string &s) {
+ return mlirStringRefCreate(s.data(), s.size());
+}
+
+inline MlirStringRef toMlirStringRef(std::string_view s) {
+ return mlirStringRefCreate(s.data(), s.size());
+}
+
+inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
+ return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
+}
+
+/// Create a block, using the current location context if no locations are
+/// specified.
+inline MlirBlock
+createBlock(const nanobind::sequence &pyArgTypes,
+ const std::optional<nanobind::sequence> &pyArgLocs) {
+ SmallVector<MlirType> argTypes;
+ argTypes.reserve(nanobind::len(pyArgTypes));
+ for (const auto &pyType : pyArgTypes)
+ argTypes.push_back(nanobind::cast<PyType &>(pyType));
+
+ SmallVector<MlirLocation> argLocs;
+ if (pyArgLocs) {
+ argLocs.reserve(nanobind::len(*pyArgLocs));
+ for (const auto &pyLoc : *pyArgLocs)
+ argLocs.push_back(nanobind::cast<PyLocation &>(pyLoc));
+ } else if (!argTypes.empty()) {
+ argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
+ }
+
+ if (argTypes.size() != argLocs.size())  // explicit arg_locs must match the number of types
+ throw nanobind::value_error(("Expected " + Twine(argTypes.size()) +
+ " locations, got: " + Twine(argLocs.size()))
+ .str()
+ .c_str());
+ return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
+}
+
+struct PyAttrBuilderMap {  // static facade over PyGlobals' attribute-builder registry (Python `AttrBuilder`)
+ static bool dunderContains(const std::string &attributeKind) {
+ return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
+ }
+ static nanobind::callable
+ dunderGetItemNamed(const std::string &attributeKind) {
+ auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+ if (!builder)  // unknown kind surfaces as a Python KeyError
+ throw nanobind::key_error(attributeKind.c_str());
+ return *builder;
+ }
+ static void dunderSetItemNamed(const std::string &attributeKind,
+ nanobind::callable func, bool replace) {
+ PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
+ replace);
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyAttrBuilderMap>(m, "AttrBuilder")
+ .def_static("contains", &PyAttrBuilderMap::dunderContains,
+ nanobind::arg("attribute_kind"),
+ "Checks whether an attribute builder is registered for the "
+ "given attribute kind.")
+ .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
+ nanobind::arg("attribute_kind"),
+ "Gets the registered attribute builder for the given "
+ "attribute kind.")
+ .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
+ nanobind::arg("attribute_kind"),
+ nanobind::arg("attr_builder"),
+ nanobind::arg("replace") = false,
+ "Register an attribute builder for building MLIR "
+ "attributes from Python values.");
+ }
+};
+
+//------------------------------------------------------------------------------
+// PyBlock
+//------------------------------------------------------------------------------
+
+inline nanobind::object PyBlock::getCapsule() {  // out-of-line member defined in a header, hence inline
+ return nanobind::steal<nanobind::object>(mlirPythonBlockToCapsule(get()));  // steal(): assumes the helper returns a new reference
+}
+
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+class PyRegionIterator {  // Python `RegionIterator`: walks an operation's regions by index
+public:
+ PyRegionIterator(PyOperationRef operation, intptr_t nextIndex)  // intptr_t matches the stored index; callers pass intptr_t
+ : operation(std::move(operation)), nextIndex(nextIndex) {}
+
+ PyRegionIterator &dunderIter() { return *this; }
+
+ PyRegion dunderNext() {
+ operation->checkValid();
+ if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {  // bound is the op's total region count
+ throw nanobind::stop_iteration();
+ }
+ MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
+ return PyRegion(operation, region);
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyRegionIterator>(m, "RegionIterator")
+ .def("__iter__", &PyRegionIterator::dunderIter,
+ "Returns an iterator over the regions in the operation.")
+ .def("__next__", &PyRegionIterator::dunderNext,
+ "Returns the next region in the iteration.");
+ }
+
+private:
+ PyOperationRef operation;
+ intptr_t nextIndex = 0;
+};
+
+/// Regions of an op are fixed-length and numerically indexed, so they are
+/// represented with a sequence-like container.
+class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
+public:
+ static constexpr const char *pyClassName = "RegionSequence";
+
+ PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumRegions(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
+
+ PyRegionIterator dunderIter() {  // NOTE(review): only startIndex is forwarded; a slice's length/step are ignored
+ operation->checkValid();
+ return PyRegionIterator(operation, startIndex);
+ }
+
+ static void bindDerived(ClassTy &c) {
+ c.def("__iter__", &PyRegionList::dunderIter,
+ "Returns an iterator over the regions in the sequence.");
+ }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyRegionList, PyRegion>;
+
+ intptr_t getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumRegions(operation->get());
+ }
+
+ PyRegion getRawElement(intptr_t pos) {
+ operation->checkValid();
+ return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
+ }
+
+ PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyRegionList(operation, startIndex, length, step);
+ }
+
+ PyOperationRef operation;
+};
+
+class PyBlockIterator {  // Python `BlockIterator`: follows mlirBlockGetNextInRegion until a null block
+public:
+ PyBlockIterator(PyOperationRef operation, MlirBlock next)
+ : operation(std::move(operation)), next(next) {}
+
+ PyBlockIterator &dunderIter() { return *this; }
+
+ PyBlock dunderNext() {
+ operation->checkValid();
+ if (mlirBlockIsNull(next)) {
+ throw nanobind::stop_iteration();
+ }
+
+ PyBlock returnBlock(operation, next);
+ next = mlirBlockGetNextInRegion(next);
+ return returnBlock;
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyBlockIterator>(m, "BlockIterator")
+ .def("__iter__", &PyBlockIterator::dunderIter,
+ "Returns an iterator over the blocks in the operation's region.")
+ .def("__next__", &PyBlockIterator::dunderNext,
+ "Returns the next block in the iteration.");
+ }
+
+private:
+ PyOperationRef operation;
+ MlirBlock next;  // block to yield next; null once exhausted
+};
+
+/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
+/// we present them as a more full-featured list-like container but optimize
+/// it for forward iteration. Blocks are always owned by a region.
+class PyBlockList {
+public:
+ PyBlockList(PyOperationRef operation, MlirRegion region)
+ : operation(std::move(operation)), region(region) {}
+
+ PyBlockIterator dunderIter() {
+ operation->checkValid();
+ return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+ }
+
+ intptr_t dunderLen() {  // O(n): counts by walking the linked list
+ operation->checkValid();
+ intptr_t count = 0;
+ MlirBlock block = mlirRegionGetFirstBlock(region);
+ while (!mlirBlockIsNull(block)) {
+ count += 1;
+ block = mlirBlockGetNextInRegion(block);
+ }
+ return count;
+ }
+
+ PyBlock dunderGetItem(intptr_t index) {
+ operation->checkValid();
+ if (index < 0) {  // Python-style negative index; costs an extra O(n) walk
+ index += dunderLen();
+ }
+ if (index < 0) {
+ throw nanobind::index_error("attempt to access out of bounds block");
+ }
+ MlirBlock block = mlirRegionGetFirstBlock(region);
+ while (!mlirBlockIsNull(block)) {
+ if (index == 0) {
+ return PyBlock(operation, block);
+ }
+ block = mlirBlockGetNextInRegion(block);
+ index -= 1;
+ }
+ throw nanobind::index_error("attempt to access out of bounds block");
+ }
+
+ PyBlock appendBlock(const nanobind::args &pyArgTypes,
+ const std::optional<nanobind::sequence> &pyArgLocs) {
+ operation->checkValid();
+ MlirBlock block =
+ createBlock(nanobind::cast<nanobind::sequence>(pyArgTypes), pyArgLocs);
+ mlirRegionAppendOwnedBlock(region, block);  // appends the new block at the end of this region
+ return PyBlock(operation, block);
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyBlockList>(m, "BlockList")
+ .def("__getitem__", &PyBlockList::dunderGetItem,
+ "Returns the block at the specified index.")
+ .def("__iter__", &PyBlockList::dunderIter,
+ "Returns an iterator over blocks in the operation's region.")
+ .def("__len__", &PyBlockList::dunderLen,
+ "Returns the number of blocks in the operation's region.")
+ .def("append", &PyBlockList::appendBlock,
+ R"(
+ Appends a new block, with argument types as positional args.
+
+ Returns:
+ The created block.
+ )",
+ nanobind::arg("args"), nanobind::kw_only(),
+ nanobind::arg("arg_locs") = std::nullopt);
+ }
+
+private:
+ PyOperationRef operation;
+ MlirRegion region;
+};
+
+class PyOperationIterator {  // Python `OperationIterator`: follows mlirOperationGetNextInBlock until null
+public:
+ PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
+ : parentOperation(std::move(parentOperation)), next(next) {}
+
+ PyOperationIterator &dunderIter() { return *this; }
+
+ nanobind::typed<nanobind::object, PyOpView> dunderNext() {
+ parentOperation->checkValid();
+ if (mlirOperationIsNull(next)) {
+ throw nanobind::stop_iteration();
+ }
+
+ PyOperationRef returnOperation =
+ PyOperation::forOperation(parentOperation->getContext(), next);
+ next = mlirOperationGetNextInBlock(next);
+ return returnOperation->createOpView();  // yields the OpView, not the raw PyOperation
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyOperationIterator>(m, "OperationIterator")
+ .def("__iter__", &PyOperationIterator::dunderIter,
+ "Returns an iterator over the operations in an operation's block.")
+ .def("__next__", &PyOperationIterator::dunderNext,
+ "Returns the next operation in the iteration.");
+ }
+
+private:
+ PyOperationRef parentOperation;
+ MlirOperation next;
+};
+
+/// Operations are exposed by the C-API as a forward-only linked list. In
+/// Python, we present them as a more full-featured list-like container but
+/// optimize it for forward iteration. Iterable operations are always owned
+/// by a block.
+class PyOperationList {
+public:
+ PyOperationList(PyOperationRef parentOperation, MlirBlock block)
+ : parentOperation(std::move(parentOperation)), block(block) {}
+
+ PyOperationIterator dunderIter() {
+ parentOperation->checkValid();
+ return PyOperationIterator(parentOperation,
+ mlirBlockGetFirstOperation(block));
+ }
+
+ intptr_t dunderLen() {  // O(n): counts by walking the block's operation list
+ parentOperation->checkValid();
+ intptr_t count = 0;
+ MlirOperation childOp = mlirBlockGetFirstOperation(block);
+ while (!mlirOperationIsNull(childOp)) {
+ count += 1;
+ childOp = mlirOperationGetNextInBlock(childOp);
+ }
+ return count;
+ }
+
+ nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index) {
+ parentOperation->checkValid();
+ if (index < 0) {  // Python-style negative index; costs an extra O(n) walk
+ index += dunderLen();
+ }
+ if (index < 0) {
+ throw nanobind::index_error("attempt to access out of bounds operation");
+ }
+ MlirOperation childOp = mlirBlockGetFirstOperation(block);
+ while (!mlirOperationIsNull(childOp)) {
+ if (index == 0) {
+ return PyOperation::forOperation(parentOperation->getContext(), childOp)
+ ->createOpView();
+ }
+ childOp = mlirOperationGetNextInBlock(childOp);
+ index -= 1;
+ }
+ throw nanobind::index_error("attempt to access out of bounds operation");
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyOperationList>(m, "OperationList")
+ .def("__getitem__", &PyOperationList::dunderGetItem,
+ "Returns the operation at the specified index.")
+ .def("__iter__", &PyOperationList::dunderIter,
+ "Returns an iterator over operations in the list.")
+ .def("__len__", &PyOperationList::dunderLen,
+ "Returns the number of operations in the list.");
+ }
+
+private:
+ PyOperationRef parentOperation;
+ MlirBlock block;
+};
+
+class PyOpOperand {  // Python `OpOperand`: exposes an operand's owning op and its position
+public:
+ PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+ nanobind::typed<nanobind::object, PyOpView> getOwner() {
+ MlirOperation owner = mlirOpOperandGetOwner(opOperand);
+ PyMlirContextRef context =  // no parent op ref is stored, so the context comes from the owner op
+ PyMlirContext::forContext(mlirOperationGetContext(owner));
+ return PyOperation::forOperation(context, owner)->createOpView();
+ }
+
+ size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyOpOperand>(m, "OpOperand")
+ .def_prop_ro("owner", &PyOpOperand::getOwner,
+ "Returns the operation that owns this operand.")
+ .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
+ "Returns the operand number in the owning operation.");
+ }
+
+private:
+ MlirOpOperand opOperand;
+};
+
+class PyOpOperandIterator {  // Python `OpOperandIterator`: follows mlirOpOperandGetNextUse until a null operand
+public:
+ PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+ PyOpOperandIterator &dunderIter() { return *this; }
+
+ PyOpOperand dunderNext() {  // NOTE(review): unlike the other iterators, no checkValid() on an owning op
+ if (mlirOpOperandIsNull(opOperand))
+ throw nanobind::stop_iteration();
+
+ PyOpOperand returnOpOperand(opOperand);
+ opOperand = mlirOpOperandGetNextUse(opOperand);
+ return returnOpOperand;
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<PyOpOperandIterator>(m, "OpOperandIterator")
+ .def("__iter__", &PyOpOperandIterator::dunderIter,
+ "Returns an iterator over operands.")
+ .def("__next__", &PyOpOperandIterator::dunderNext,
+ "Returns the next operand in the iteration.");
+ }
+
+private:
+ MlirOpOperand opOperand;
+};
+
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ // and redefine bindDerived.
+ using ClassTy = nanobind::class_<DerivedTy, PyValue>;
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ PyConcreteValue() = default;
+ PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+ : PyValue(operationRef, value) {}
+ PyConcreteValue(PyValue &orig)  // checked conversion: castFrom() throws value_error on a kind mismatch
+ : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+ /// Attempts to cast the original value to the derived type and throws on
+ /// type mismatches.
+ static MlirValue castFrom(PyValue &orig) {
+ if (!DerivedTy::isaFunction(orig.get())) {
+ auto origRepr =
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
+ throw nanobind::value_error((Twine("Cannot cast value to ") +
+ DerivedTy::pyClassName + " (from " +
+ origRepr + ")")
+ .str()
+ .c_str());
+ }
+ return orig.get();
+ }
+
+ /// Binds the Python module objects to functions of this class.
+ static void bind(nanobind::module_ &m) {
+ auto cls = ClassTy(
+ m, DerivedTy::pyClassName, nanobind::is_generic(),
+ nanobind::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
+ .str()
+ .c_str()));
+ cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),  // NOTE(review): index 0 is __init__'s return (None); keep_alive<1, 2> may be intended — confirm
+ nanobind::arg("value"));
+ cls.def_static(
+ "isinstance",
+ [](PyValue &otherValue) -> bool {
+ return DerivedTy::isaFunction(otherValue);
+ },
+ nanobind::arg("other_value"));
+ cls.def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
+ return self.maybeDownCast();
+ });
+ DerivedTy::bindDerived(cls);
+ }
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m) {}
+};
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+ static constexpr const char *pyClassName = "OpResult";
+ using PyConcreteValue::PyConcreteValue;  // inherits ctors, including the checked PyValue& conversion
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "owner",
+ [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOperation> {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in "
+ "the IR");
+ return self.getParentOperation().getObject();  // the Python object behind the parent op reference
+ },
+ "Returns the operation that produces this result.");
+ c.def_prop_ro(
+ "result_number",
+ [](PyOpResult &self) {
+ return mlirOpResultGetResultNumber(self.get());
+ },
+ "Returns the position of this result in the operation's result list.");
+ }
+};
+
+/// Returns the list of types of the values held by container.
+template <typename Container>
+inline std::vector<nanobind::typed<nanobind::object, PyType>>  // inline, not static: defined in a header
+getValueTypes(Container &container, PyMlirContextRef &context) {
+ std::vector<nanobind::typed<nanobind::object, PyType>> result;
+ result.reserve(container.size());
+ for (intptr_t i = 0, e = container.size(); i < e; ++i) {  // intptr_t matches the Sliceable lists' index type
+ result.push_back(PyType(context->getRef(),
+ mlirValueGetType(container.getElement(i).get()))
+ .maybeDownCast());
+ }
+ return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, so random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+ static constexpr const char *pyClassName = "OpResultList";
+ using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+ PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumResults(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "types",
+ [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all results in this result list.");
+ c.def_prop_ro(
+ "owner",
+ [](PyOpResultList &self)
+ -> nanobind::typed<nanobind::object, PyOpView> {
+ return self.operation->createOpView();
+ },
+ "Returns the operation that owns this result list.");
+ }
+
+ PyOperationRef &getOperation() { return operation; }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpResultList, PyOpResult>;
+
+ intptr_t getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumResults(operation->get());
+ }
+
+ PyOpResult getRawElement(intptr_t index) {  // NOTE(review): no checkValid() here, unlike getRawNumElements()
+ PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+ return PyOpResult(value);  // goes through the checked PyValue& conversion
+ }
+
+ PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyOpResultList(operation, startIndex, length, step);
+ }
+
+ PyOperationRef operation;
+};
+
+/// Python wrapper for MlirBlockArgument.
+class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
+ static constexpr const char *pyClassName = "BlockArgument";
+ using PyConcreteValue::PyConcreteValue;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "owner",
+ [](PyBlockArgument &self) {
+ return PyBlock(self.getParentOperation(),  // the PyBlock shares this argument's parent-op reference
+ mlirBlockArgumentGetOwner(self.get()));
+ },
+ "Returns the block that owns this argument.");
+ c.def_prop_ro(
+ "arg_number",
+ [](PyBlockArgument &self) {
+ return mlirBlockArgumentGetArgNumber(self.get());
+ },
+ "Returns the position of this argument in the block's argument list.");
+ c.def(
+ "set_type",
+ [](PyBlockArgument &self, PyType type) {
+ return mlirBlockArgumentSetType(self.get(), type);
+ },
+ nanobind::arg("type"), "Sets the type of this block argument.");
+ c.def(
+ "set_location",
+ [](PyBlockArgument &self, PyLocation loc) {
+ return mlirBlockArgumentSetLocation(self.get(), loc);
+ },
+ nanobind::arg("loc"), "Sets the location of this block argument.");
+ }
+};
+
+/// A list of block arguments. Internally, these are stored as consecutive
+/// elements, so random access is cheap. The argument list is associated with the
+/// operation that contains the block (detached blocks are not allowed in
+/// Python bindings) and extends its lifetime.
+class PyBlockArgumentList
+ : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
+public:
+ static constexpr const char *pyClassName = "BlockArgumentList";
+ using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+ PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
+ intptr_t startIndex = 0, intptr_t length = -1,
+ intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumArguments(block) : length,
+ step),
+ operation(std::move(operation)), block(block) {}
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "types",
+ [](PyBlockArgumentList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all arguments in this argument list.");
+ }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+ /// Returns the number of arguments in the list.
+ intptr_t getRawNumElements() {
+ operation->checkValid();
+ return mlirBlockGetNumArguments(block);
+ }
+
+ /// Returns the `pos`-th element in the list.
+ PyBlockArgument getRawElement(intptr_t pos) {
+ MlirValue argument = mlirBlockGetArgument(block, pos);
+ return PyBlockArgument(operation, argument);  // uses the unchecked (operationRef, value) constructor
+ }
+
+ /// Returns a sublist of this list.
+ PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) {
+ return PyBlockArgumentList(operation, block, startIndex, length, step);
+ }
+
+ PyOperationRef operation;
+ MlirBlock block;
+};
+
+/// A list of operation operands. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) operand list is associated
+/// with the operation whose operands these are, and thus extends the lifetime
+/// of this operation.
+class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
+public:
+  static constexpr const char *pyClassName = "OpOperandList";
+  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
+
+  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+                  intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumOperands(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  /// Replaces the operand at slice position `index` (negative values count
+  /// from the end of this view) with `value`. Raises IndexError when `index`
+  /// is out of range.
+  void dunderSetItem(intptr_t index, PyValue value) {
+    operation->checkValid();
+    intptr_t wrapped = wrapIndex(index);
+    // wrapIndex signals an out-of-range index with a negative result; never
+    // forward that to the C API.
+    if (wrapped < 0)
+      throw nanobind::index_error("attempt to set out of bounds operand");
+    // Translate the slice-relative position into the operation's own operand
+    // index so assignments through a slice hit the right operand.
+    mlirOperationSetOperand(operation->get(), linearizeIndex(wrapped),
+                            value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem,
+          nanobind::arg("index"), nanobind::arg("value"),
+          "Sets the operand at the specified index to a new value.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpOperandList, PyValue>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumOperands(operation->get());
+  }
+
+  /// Wraps the operand at absolute position `pos`, parenting the resulting
+  /// PyValue to the operation (or block's parent operation) that defines it.
+  PyValue getRawElement(intptr_t pos) {
+    MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
+    MlirOperation owner;
+    if (mlirValueIsAOpResult(operand)) {
+      owner = mlirOpResultGetOwner(operand);
+    } else if (mlirValueIsABlockArgument(operand)) {
+      owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
+    } else {
+      // Keep the debug assertion, but never fall through with `owner`
+      // uninitialized when assertions are compiled out.
+      assert(false && "Value must be an block arg or op result.");
+      throw nanobind::value_error("Value must be a block arg or op result.");
+    }
+    PyOperationRef pyOwner =
+        PyOperation::forOperation(operation->getContext(), owner);
+    return PyValue(pyOwner, operand);
+  }
+
+  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpOperandList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+/// A list of operation successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation whose successors these are, and thus extends
+/// the lifetime of this operation.
+class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "OpSuccessors";
+
+  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  /// Replaces the successor at slice position `index` (negative values count
+  /// from the end of this view) with `block`. Raises IndexError when `index`
+  /// is out of range.
+  void dunderSetItem(intptr_t index, PyBlock block) {
+    operation->checkValid();
+    intptr_t wrapped = wrapIndex(index);
+    // wrapIndex signals an out-of-range index with a negative result; never
+    // forward that to the C API.
+    if (wrapped < 0)
+      throw nanobind::index_error("attempt to set out of bounds successor");
+    // Translate the slice-relative position into the operation's own
+    // successor index so assignments through a slice hit the right entry.
+    mlirOperationSetSuccessor(operation->get(), linearizeIndex(wrapped),
+                              block.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
+          nanobind::arg("block"),
+          "Sets the successor block at the specified index.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpSuccessors, PyBlock>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumSuccessors(operation->get());
+  }
+
+  /// Wraps the successor block at absolute position `pos`.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock successor = mlirOperationGetSuccessor(operation->get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpSuccessors(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+/// Successor blocks of a block, exposed as a sliceable sequence with O(1)
+/// indexed access. The view keeps references to both the operation and the
+/// block it was created from, so both stay alive while it exists.
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockSuccessors";
+
+  /// A `length` of -1 means "every successor of `block`".
+  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+                    intptr_t startIndex = 0, intptr_t length = -1,
+                    intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length != -1 ? length
+                               : mlirBlockGetNumSuccessors(block.get()),
+                  step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+  /// CRTP hook: successor count of the underlying block, after validation.
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumSuccessors(block.get());
+  }
+
+  /// CRTP hook: wraps the successor at absolute position `pos`.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  /// CRTP hook: builds a narrower view over the same block.
+  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return {block, operation, startIndex, length, step};
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
+/// Predecessor blocks of a block, exposed as a sliceable sequence. The view
+/// keeps references to both the operation and the block it was created from,
+/// so both stay alive while it exists.
+///
+/// WARNING: unlike the other sliceables here, indexed access is not cheap:
+/// mlirBlockGetPredecessor walks the use-def chain (of block operands) again
+/// on every call.
+class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockPredecessors";
+
+  /// A `length` of -1 means "every predecessor of `block`".
+  PyBlockPredecessors(PyBlock block, PyOperationRef operation,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length != -1 ? length
+                               : mlirBlockGetNumPredecessors(block.get()),
+                  step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockPredecessors, PyBlock>;
+
+  /// CRTP hook: predecessor count of the underlying block, after validation.
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumPredecessors(block.get());
+  }
+
+  /// CRTP hook: wraps the predecessor at absolute position `pos`.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+    return PyBlock(operation, predecessor);
+  }
+
+  /// CRTP hook: builds a narrower view over the same block.
+  PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return {block, operation, startIndex, length, step};
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
+/// A mapping view over an operation's attributes. Indexing by name yields the
+/// attribute value; indexing by integer position yields a named attribute.
+/// The view holds a reference to the operation, keeping it alive.
+class PyOpAttributeMap {
+public:
+  PyOpAttributeMap(PyOperationRef operation)
+      : operation(std::move(operation)) {}
+
+  /// Returns the attribute called `name`, downcast to its concrete Python
+  /// class. Raises KeyError when the operation has no such attribute.
+  nanobind::typed<nanobind::object, PyAttribute>
+  dunderGetItemNamed(const std::string &name) {
+    MlirAttribute attr = mlirOperationGetAttributeByName(
+        getCheckedOperation(), toMlirStringRef(name));
+    if (mlirAttributeIsNull(attr)) {
+      throw nanobind::key_error("attempt to access a non-existent attribute");
+    }
+    return PyAttribute(operation->getContext(), attr).maybeDownCast();
+  }
+
+  /// Returns the named attribute at position `index`; negative indices count
+  /// from the end. Raises IndexError when `index` is out of range.
+  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    // dunderLen() validates the operation; compute the count only once.
+    intptr_t numAttrs = dunderLen();
+    if (index < 0) {
+      index += numAttrs;
+    }
+    if (index < 0 || index >= numAttrs) {
+      throw nanobind::index_error("attempt to access out of bounds attribute");
+    }
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), index);
+    MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+    return PyNamedAttribute(namedAttr.attribute,
+                            std::string(name.data, name.length));
+  }
+
+  /// Sets (or replaces) the attribute called `name`.
+  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
+    mlirOperationSetAttributeByName(getCheckedOperation(),
+                                    toMlirStringRef(name), attr);
+  }
+
+  /// Removes the attribute called `name`. Raises KeyError if it is absent.
+  void dunderDelItem(const std::string &name) {
+    int removed = mlirOperationRemoveAttributeByName(getCheckedOperation(),
+                                                     toMlirStringRef(name));
+    if (!removed)
+      throw nanobind::key_error("attempt to delete a non-existent attribute");
+  }
+
+  /// Number of attributes attached to the operation.
+  intptr_t dunderLen() {
+    return mlirOperationGetNumAttributes(getCheckedOperation());
+  }
+
+  /// True when the operation has an attribute called `name`.
+  bool dunderContains(const std::string &name) {
+    return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
+        getCheckedOperation(), toMlirStringRef(name)));
+  }
+
+  /// Invokes `fn(name, attribute)` for each attribute of `op`, in positional
+  /// order.
+  static void
+  forEachAttr(MlirOperation op,
+              llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+    intptr_t n = mlirOperationGetNumAttributes(op);
+    for (intptr_t i = 0; i < n; ++i) {
+      MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
+      MlirStringRef name = mlirIdentifierStr(na.name);
+      fn(name, na.attribute);
+    }
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__contains__", &PyOpAttributeMap::dunderContains,
+             nanobind::arg("name"),
+             "Checks if an attribute with the given name exists in the map.")
+        .def("__len__", &PyOpAttributeMap::dunderLen,
+             "Returns the number of attributes in the map.")
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
+             nanobind::arg("name"), "Gets an attribute by name.")
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
+             nanobind::arg("index"), "Gets a named attribute by index.")
+        .def("__setitem__", &PyOpAttributeMap::dunderSetItem,
+             nanobind::arg("name"), nanobind::arg("attr"),
+             "Sets an attribute with the given name.")
+        .def("__delitem__", &PyOpAttributeMap::dunderDelItem,
+             nanobind::arg("name"), "Deletes an attribute with the given name.")
+        .def(
+            "__iter__",
+            [](PyOpAttributeMap &self) {
+              nanobind::list keys;
+              PyOpAttributeMap::forEachAttr(
+                  self.getCheckedOperation(),
+                  [&](MlirStringRef name, MlirAttribute) {
+                    keys.append(nanobind::str(name.data, name.length));
+                  });
+              return nanobind::iter(keys);
+            },
+            "Iterates over attribute names.")
+        .def(
+            "keys",
+            [](PyOpAttributeMap &self) {
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.getCheckedOperation(),
+                  [&](MlirStringRef name, MlirAttribute) {
+                    out.append(nanobind::str(name.data, name.length));
+                  });
+              return out;
+            },
+            "Returns a list of attribute names.")
+        .def(
+            "values",
+            [](PyOpAttributeMap &self) {
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.getCheckedOperation(),
+                  [&](MlirStringRef, MlirAttribute attr) {
+                    out.append(PyAttribute(self.operation->getContext(), attr)
+                                   .maybeDownCast());
+                  });
+              return out;
+            },
+            "Returns a list of attribute values.")
+        .def(
+            "items",
+            [](PyOpAttributeMap &self) {
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.getCheckedOperation(),
+                  [&](MlirStringRef name, MlirAttribute attr) {
+                    out.append(nanobind::make_tuple(
+                        nanobind::str(name.data, name.length),
+                        PyAttribute(self.operation->getContext(), attr)
+                            .maybeDownCast()));
+                  });
+              return out;
+            },
+            "Returns a list of `(name, attribute)` tuples.");
+  }
+
+private:
+  /// Calls checkValid() on the referenced operation (which raises if it is no
+  /// longer valid) and returns its handle. This matches the validation the
+  /// sibling list views perform before touching the operation.
+  MlirOperation getCheckedOperation() {
+    operation->checkValid();
+    return operation->get();
+  }
+
+  PyOperationRef operation;
+};
+/// NOTE(review): declaration only; the definition lives outside this header.
+/// Judging by the name it returns the single result of `operation` -- confirm
+/// against the definition.
MlirValue getUniqueResult(MlirOperation operation);
} // namespace python
} // namespace mlir
@@ -1345,4 +2352,4 @@ struct type_caster<mlir::python::DefaultingPyLocation>
} // namespace detail
} // namespace nanobind
-#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
+#endif // MLIR_BINDINGS_PYTHON_IRCORE_H
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
similarity index 100%
rename from mlir/lib/Bindings/Python/NanobindUtils.h
rename to mlir/include/mlir/Bindings/Python/NanobindUtils.h
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 0d1d9e89f92f6..a87918a05b126 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Dialect/SMT.h"
#include "mlir-c/IR.h"
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/Globals.cpp
similarity index 97%
rename from mlir/lib/Bindings/Python/IRModule.cpp
rename to mlir/lib/Bindings/Python/Globals.cpp
index 0de2f1711829b..bc6b210426221 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -6,25 +6,27 @@
//
//===----------------------------------------------------------------------===//
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include <optional>
#include <vector>
-#include "Globals.h"
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/Globals.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
+// clang-format on
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
+namespace mlir::python {
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
@@ -265,3 +267,7 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
}
return isUserTracebackFilenameCache[file];
}
+
+nanobind::ft_mutex PyGlobalDebugFlag::mutex;
+
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 7147f2cbad149..624d8f0fa57ce 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -13,11 +13,13 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir/Bindings/Python/IRCore.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/IntegerSet.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Support/LLVM.h"
@@ -509,7 +511,8 @@ PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
rawIntegerSet);
}
-void mlir::python::populateIRAffine(nb::module_ &m) {
+namespace mlir::python {
+void populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
//----------------------------------------------------------------------------
@@ -995,3 +998,4 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
PyIntegerSetConstraint::bind(m);
PyIntegerSetConstraintList::bind(m);
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c0a945e3f4f3b..36367e658697c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -12,12 +12,12 @@
#include <string_view>
#include <utility>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
@@ -1799,7 +1799,8 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-void mlir::python::populateIRAttributes(nb::module_ &m) {
+namespace mlir::python {
+void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
@@ -1851,4 +1852,18 @@ void mlir::python::populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through nanobind, so
+ // instead the exception class is defined in Python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2e0c2b895216f..88cffb64906d7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
+// clang-format off
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
@@ -22,6 +24,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
+#include <iostream>
#include <optional>
namespace nb = nanobind;
@@ -33,504 +36,7 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-static const char kModuleParseDocstring[] =
- R"(Parses a module's assembly format from a string.
-
-Returns a new MlirModule or raises an MLIRError if the parsing fails.
-
-See also: https://mlir.llvm.org/docs/LangRef/
-)";
-
-static const char kDumpDocstring[] =
- "Dumps a debug representation of the object to stderr.";
-
-static const char kValueReplaceAllUsesExceptDocstring[] =
- R"(Replace all uses of this value with the `with` value, except for those
-in `exceptions`. `exceptions` can be either a single operation or a list of
-operations.
-)";
-
-//------------------------------------------------------------------------------
-// Utilities.
-//------------------------------------------------------------------------------
-
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-static nb::object classmethod(Func f, Args... args) {
- nb::object cf = nb::cpp_function(f, args...);
- return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
-}
-
-static nb::object
-createCustomDialectWrapper(const std::string &dialectNamespace,
- nb::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
- if (!dialectClass) {
- // Use the base class.
- return nb::cast(PyDialect(std::move(dialectDescriptor)));
- }
-
- // Create the custom implementation.
- return (*dialectClass)(std::move(dialectDescriptor));
-}
-
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(std::string_view s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(const nb::bytes &s) {
- return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
-}
-
-/// Create a block, using the current location context if no locations are
-/// specified.
-static MlirBlock createBlock(const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- SmallVector<MlirType> argTypes;
- argTypes.reserve(nb::len(pyArgTypes));
- for (const auto &pyType : pyArgTypes)
- argTypes.push_back(nb::cast<PyType &>(pyType));
-
- SmallVector<MlirLocation> argLocs;
- if (pyArgLocs) {
- argLocs.reserve(nb::len(*pyArgLocs));
- for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
- } else if (!argTypes.empty()) {
- argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
- }
-
- if (argTypes.size() != argLocs.size())
- throw nb::value_error(("Expected " + Twine(argTypes.size()) +
- " locations, got: " + Twine(argLocs.size()))
- .str()
- .c_str());
- return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
-}
-
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nb::object &o, bool enable) {
- nb::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nb::object &) {
- nb::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- static void bind(nb::module_ &m) {
- // Debug flags.
- nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- "types"_a, "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- "types"_a,
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- static nb::ft_mutex mutex;
-};
-
-nb::ft_mutex PyGlobalDebugFlag::mutex;
-
-struct PyAttrBuilderMap {
- static bool dunderContains(const std::string &attributeKind) {
- return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
- }
- static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
- auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
- if (!builder)
- throw nb::key_error(attributeKind.c_str());
- return *builder;
- }
- static void dunderSetItemNamed(const std::string &attributeKind,
- nb::callable func, bool replace) {
- PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
- .def_static("contains", &PyAttrBuilderMap::dunderContains,
- "attribute_kind"_a,
- "Checks whether an attribute builder is registered for the "
- "given attribute kind.")
- .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
- "attribute_kind"_a,
- "Gets the registered attribute builder for the given "
- "attribute kind.")
- .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
- "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
- "Register an attribute builder for building MLIR "
- "attributes from Python values.");
- }
-};
-
-//------------------------------------------------------------------------------
-// PyBlock
-//------------------------------------------------------------------------------
-
-nb::object PyBlock::getCapsule() {
- return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
-}
-
-//------------------------------------------------------------------------------
-// Collections.
-//------------------------------------------------------------------------------
-
-namespace {
-
-class PyRegionIterator {
-public:
- PyRegionIterator(PyOperationRef operation, int nextIndex)
- : operation(std::move(operation)), nextIndex(nextIndex) {}
-
- PyRegionIterator &dunderIter() { return *this; }
-
- PyRegion dunderNext() {
- operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw nb::stop_iteration();
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
- return PyRegion(operation, region);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyRegionIterator>(m, "RegionIterator")
- .def("__iter__", &PyRegionIterator::dunderIter,
- "Returns an iterator over the regions in the operation.")
- .def("__next__", &PyRegionIterator::dunderNext,
- "Returns the next region in the iteration.");
- }
-
-private:
- PyOperationRef operation;
- intptr_t nextIndex = 0;
-};
-
-/// Regions of an op are fixed length and indexed numerically so are represented
-/// with a sequence-like container.
-class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
-public:
- static constexpr const char *pyClassName = "RegionSequence";
-
- PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumRegions(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- PyRegionIterator dunderIter() {
- operation->checkValid();
- return PyRegionIterator(operation, startIndex);
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__iter__", &PyRegionList::dunderIter,
- "Returns an iterator over the regions in the sequence.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyRegionList, PyRegion>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumRegions(operation->get());
- }
-
- PyRegion getRawElement(intptr_t pos) {
- operation->checkValid();
- return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
- }
-
- PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyRegionList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-class PyBlockIterator {
-public:
- PyBlockIterator(PyOperationRef operation, MlirBlock next)
- : operation(std::move(operation)), next(next) {}
-
- PyBlockIterator &dunderIter() { return *this; }
-
- PyBlock dunderNext() {
- operation->checkValid();
- if (mlirBlockIsNull(next)) {
- throw nb::stop_iteration();
- }
-
- PyBlock returnBlock(operation, next);
- next = mlirBlockGetNextInRegion(next);
- return returnBlock;
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockIterator>(m, "BlockIterator")
- .def("__iter__", &PyBlockIterator::dunderIter,
- "Returns an iterator over the blocks in the operation's region.")
- .def("__next__", &PyBlockIterator::dunderNext,
- "Returns the next block in the iteration.");
- }
-
-private:
- PyOperationRef operation;
- MlirBlock next;
-};
-
-/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
-/// we present them as a more full-featured list-like container but optimize
-/// it for forward iteration. Blocks are always owned by a region.
-class PyBlockList {
-public:
- PyBlockList(PyOperationRef operation, MlirRegion region)
- : operation(std::move(operation)), region(region) {}
-
- PyBlockIterator dunderIter() {
- operation->checkValid();
- return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
- }
-
- intptr_t dunderLen() {
- operation->checkValid();
- intptr_t count = 0;
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- count += 1;
- block = mlirBlockGetNextInRegion(block);
- }
- return count;
- }
-
- PyBlock dunderGetItem(intptr_t index) {
- operation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds block");
- }
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- if (index == 0) {
- return PyBlock(operation, block);
- }
- block = mlirBlockGetNextInRegion(block);
- index -= 1;
- }
- throw nb::index_error("attempt to access out of bounds block");
- }
-
- PyBlock appendBlock(const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- operation->checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- mlirRegionAppendOwnedBlock(region, block);
- return PyBlock(operation, block);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockList>(m, "BlockList")
- .def("__getitem__", &PyBlockList::dunderGetItem,
- "Returns the block at the specified index.")
- .def("__iter__", &PyBlockList::dunderIter,
- "Returns an iterator over blocks in the operation's region.")
- .def("__len__", &PyBlockList::dunderLen,
- "Returns the number of blocks in the operation's region.")
- .def("append", &PyBlockList::appendBlock,
- R"(
- Appends a new block, with argument types as positional args.
-
- Returns:
- The created block.
- )",
- nb::arg("args"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt);
- }
-
-private:
- PyOperationRef operation;
- MlirRegion region;
-};
-
-class PyOperationIterator {
-public:
- PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
- : parentOperation(std::move(parentOperation)), next(next) {}
-
- PyOperationIterator &dunderIter() { return *this; }
-
- nb::typed<nb::object, PyOpView> dunderNext() {
- parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
- throw nb::stop_iteration();
- }
-
- PyOperationRef returnOperation =
- PyOperation::forOperation(parentOperation->getContext(), next);
- next = mlirOperationGetNextInBlock(next);
- return returnOperation->createOpView();
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationIterator>(m, "OperationIterator")
- .def("__iter__", &PyOperationIterator::dunderIter,
- "Returns an iterator over the operations in an operation's block.")
- .def("__next__", &PyOperationIterator::dunderNext,
- "Returns the next operation in the iteration.");
- }
-
-private:
- PyOperationRef parentOperation;
- MlirOperation next;
-};
-
-/// Operations are exposed by the C-API as a forward-only linked list. In
-/// Python, we present them as a more full-featured list-like container but
-/// optimize it for forward iteration. Iterable operations are always owned
-/// by a block.
-class PyOperationList {
-public:
- PyOperationList(PyOperationRef parentOperation, MlirBlock block)
- : parentOperation(std::move(parentOperation)), block(block) {}
-
- PyOperationIterator dunderIter() {
- parentOperation->checkValid();
- return PyOperationIterator(parentOperation,
- mlirBlockGetFirstOperation(block));
- }
-
- intptr_t dunderLen() {
- parentOperation->checkValid();
- intptr_t count = 0;
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- count += 1;
- childOp = mlirOperationGetNextInBlock(childOp);
- }
- return count;
- }
-
- nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
- parentOperation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds operation");
- }
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- if (index == 0) {
- return PyOperation::forOperation(parentOperation->getContext(), childOp)
- ->createOpView();
- }
- childOp = mlirOperationGetNextInBlock(childOp);
- index -= 1;
- }
- throw nb::index_error("attempt to access out of bounds operation");
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationList>(m, "OperationList")
- .def("__getitem__", &PyOperationList::dunderGetItem,
- "Returns the operation at the specified index.")
- .def("__iter__", &PyOperationList::dunderIter,
- "Returns an iterator over operations in the list.")
- .def("__len__", &PyOperationList::dunderLen,
- "Returns the number of operations in the list.");
- }
-
-private:
- PyOperationRef parentOperation;
- MlirBlock block;
-};
-
-class PyOpOperand {
-public:
- PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
-
- nb::typed<nb::object, PyOpView> getOwner() {
- MlirOperation owner = mlirOpOperandGetOwner(opOperand);
- PyMlirContextRef context =
- PyMlirContext::forContext(mlirOperationGetContext(owner));
- return PyOperation::forOperation(context, owner)->createOpView();
- }
-
- size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperand>(m, "OpOperand")
- .def_prop_ro("owner", &PyOpOperand::getOwner,
- "Returns the operation that owns this operand.")
- .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
- "Returns the operand number in the owning operation.");
- }
-
-private:
- MlirOpOperand opOperand;
-};
-
-class PyOpOperandIterator {
-public:
- PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
-
- PyOpOperandIterator &dunderIter() { return *this; }
-
- PyOpOperand dunderNext() {
- if (mlirOpOperandIsNull(opOperand))
- throw nb::stop_iteration();
-
- PyOpOperand returnOpOperand(opOperand);
- opOperand = mlirOpOperandGetNextUse(opOperand);
- return returnOpOperand;
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
- .def("__iter__", &PyOpOperandIterator::dunderIter,
- "Returns an iterator over operands.")
- .def("__next__", &PyOpOperandIterator::dunderNext,
- "Returns the next operand in the iteration.");
- }
-
-private:
- MlirOpOperand opOperand;
-};
-
-} // namespace
-
+namespace mlir::python {
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -1448,163 +954,6 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
-namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- // and redefine bindDerived.
- using ClassTy = nb::class_<DerivedTy, PyValue>;
- using IsAFunctionTy = bool (*)(MlirValue);
-
- PyConcreteValue() = default;
- PyConcreteValue(PyOperationRef operationRef, MlirValue value)
- : PyValue(operationRef, value) {}
- PyConcreteValue(PyValue &orig)
- : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
- /// Attempts to cast the original value to the derived type and throws on
- /// type mismatches.
- static MlirValue castFrom(PyValue &orig) {
- if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
- throw nb::value_error((Twine("Cannot cast value to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str()
- .c_str());
- }
- return orig.get();
- }
-
- /// Binds the Python module objects to functions of this class.
- static void bind(nb::module_ &m) {
- auto cls = ClassTy(
- m, DerivedTy::pyClassName, nb::is_generic(),
- nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
- .str()
- .c_str()));
- cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
- cls.def_static(
- "isinstance",
- [](PyValue &otherValue) -> bool {
- return DerivedTy::isaFunction(otherValue);
- },
- nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
- return self.maybeDownCast();
- });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-} // namespace
-
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
- static constexpr const char *pyClassName = "OpResult";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in "
- "the IR");
- return self.getParentOperation().getObject();
- },
- "Returns the operation that produces this result.");
- c.def_prop_ro(
- "result_number",
- [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- },
- "Returns the position of this result in the operation's result list.");
- }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<nb::typed<nb::object, PyType>>
-getValueTypes(Container &container, PyMlirContextRef &context) {
- std::vector<nb::typed<nb::object, PyType>> result;
- result.reserve(container.size());
- for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(PyType(context->getRef(),
- mlirValueGetType(container.getElement(i).get()))
- .maybeDownCast());
- }
- return result;
-}
-
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
- static constexpr const char *pyClassName = "OpResultList";
- using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
- PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all results in this result list.");
- c.def_prop_ro(
- "owner",
- [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
- return self.operation->createOpView();
- },
- "Returns the operation that owns this result list.");
- }
-
- PyOperationRef &getOperation() { return operation; }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpResultList, PyOpResult>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
-
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
-
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------
@@ -1706,7 +1055,7 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
}
}
-static MlirValue getUniqueResult(MlirOperation operation) {
+MlirValue getUniqueResult(MlirOperation operation) {
auto numResults = mlirOperationGetNumResults(operation);
if (numResults != 1) {
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
@@ -2319,2640 +1672,11 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-namespace {
-
-/// Python wrapper for MlirBlockArgument.
-class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
- static constexpr const char *pyClassName = "BlockArgument";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- },
- "Returns the block that owns this argument.");
- c.def_prop_ro(
- "arg_number",
- [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- },
- "Returns the position of this argument in the block's argument list.");
- c.def(
- "set_type",
- [](PyBlockArgument &self, PyType type) {
- return mlirBlockArgumentSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of this block argument.");
- c.def(
- "set_location",
- [](PyBlockArgument &self, PyLocation loc) {
- return mlirBlockArgumentSetLocation(self.get(), loc);
- },
- nb::arg("loc"), "Sets the location of this block argument.");
- }
-};
-
-/// A list of block arguments. Internally, these are stored as consecutive
-/// elements, random access is cheap. The argument list is associated with the
-/// operation that contains the block (detached blocks are not allowed in
-/// Python bindings) and extends its lifetime.
-class PyBlockArgumentList
- : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
-public:
- static constexpr const char *pyClassName = "BlockArgumentList";
- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
-
- PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumArguments(block) : length,
- step),
- operation(std::move(operation)), block(block) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all arguments in this argument list.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
-
- /// Returns the number of arguments in the list.
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirBlockGetNumArguments(block);
- }
-
- /// Returns `pos`-the element in the list.
- PyBlockArgument getRawElement(intptr_t pos) {
- MlirValue argument = mlirBlockGetArgument(block, pos);
- return PyBlockArgument(operation, argument);
- }
-
- /// Returns a sublist of this list.
- PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockArgumentList(operation, block, startIndex, length, step);
- }
-
- PyOperationRef operation;
- MlirBlock block;
-};
-
-/// A list of operation operands. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) operand list is associated
-/// with the operation whose operands these are, and thus extends the lifetime
-/// of this operation.
-class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
-public:
- static constexpr const char *pyClassName = "OpOperandList";
- using SliceableT = Sliceable<PyOpOperandList, PyValue>;
-
- PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumOperands(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyValue value) {
- index = wrapIndex(index);
- mlirOperationSetOperand(operation->get(), index, value.get());
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"),
- nb::arg("value"),
- "Sets the operand at the specified index to a new value.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpOperandList, PyValue>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumOperands(operation->get());
- }
-
- PyValue getRawElement(intptr_t pos) {
- MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
- return PyValue(pyOwner, operand);
- }
-
- PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpOperandList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-/// A list of operation successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation whose successors these are, and thus extends
-/// the lifetime of this operation.
-class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "OpSuccessors";
-
- PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumSuccessors(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyBlock block) {
- index = wrapIndex(index);
- mlirOperationSetSuccessor(operation->get(), index, block.get());
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"),
- nb::arg("block"), "Sets the successor block at the specified index.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpSuccessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumSuccessors(operation->get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
- return PyBlock(operation, block);
- }
-
- PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpSuccessors(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-/// A list of block successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation and block whose successors these are, and thus
-/// extends the lifetime of this operation and block.
-class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockSuccessors";
-
- PyBlockSuccessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumSuccessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockSuccessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumSuccessors(block.get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
-
- PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyBlockSuccessors(block, operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of block predecessors. The (returned) predecessor list is
-/// associated with the operation and block whose predecessors these are, and
-/// thus extends the lifetime of this operation and block.
-///
-/// WARNING: This Sliceable is more expensive than the others here because
-/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
-/// operands) anew for each indexed access.
-class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockPredecessors";
-
- PyBlockPredecessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumPredecessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockPredecessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumPredecessors(block.get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
-
- PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockPredecessors(block, operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of operation attributes. Can be indexed by name, producing
-/// attributes, or by index, producing named attributes.
-class PyOpAttributeMap {
-public:
- PyOpAttributeMap(PyOperationRef operation)
- : operation(std::move(operation)) {}
-
- nb::typed<nb::object, PyAttribute>
- dunderGetItemNamed(const std::string &name) {
- MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
- throw nb::key_error("attempt to access a non-existent attribute");
- }
- return PyAttribute(operation->getContext(), attr).maybeDownCast();
- }
-
- PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0 || index >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr =
- mlirOperationGetAttribute(operation->get(), index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data,
- mlirIdentifierStr(namedAttr.name).length));
- }
-
- void dunderSetItem(const std::string &name, const PyAttribute &attr) {
- mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
- attr);
- }
-
- void dunderDelItem(const std::string &name) {
- int removed = mlirOperationRemoveAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (!removed)
- throw nb::key_error("attempt to delete a non-existent attribute");
- }
-
- intptr_t dunderLen() {
- return mlirOperationGetNumAttributes(operation->get());
- }
-
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
- operation->get(), toMlirStringRef(name)));
- }
-
- static void
- forEachAttr(MlirOperation op,
- llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
- intptr_t n = mlirOperationGetNumAttributes(op);
- for (intptr_t i = 0; i < n; ++i) {
- MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
- MlirStringRef name = mlirIdentifierStr(na.name);
- fn(name, na.attribute);
- }
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
- .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"),
- "Checks if an attribute with the given name exists in the map.")
- .def("__len__", &PyOpAttributeMap::dunderLen,
- "Returns the number of attributes in the map.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
- nb::arg("name"), "Gets an attribute by name.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
- nb::arg("index"), "Gets a named attribute by index.")
- .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"),
- nb::arg("attr"), "Sets an attribute with the given name.")
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
- "Deletes an attribute with the given name.")
- .def(
- "__iter__",
- [](PyOpAttributeMap &self) {
- nb::list keys;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- keys.append(nb::str(name.data, name.length));
- });
- return nb::iter(keys);
- },
- "Iterates over attribute names.")
- .def(
- "keys",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- out.append(nb::str(name.data, name.length));
- });
- return out;
- },
- "Returns a list of attribute names.")
- .def(
- "values",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef, MlirAttribute attr) {
- out.append(PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast());
- });
- return out;
- },
- "Returns a list of attribute values.")
- .def(
- "items",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute attr) {
- out.append(nb::make_tuple(
- nb::str(name.data, name.length),
- PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast()));
- });
- return out;
- },
- "Returns a list of `(name, attribute)` tuples.");
- }
-
-private:
- PyOperationRef operation;
-};
-
-// see
-// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
-
-#ifndef _Py_CAST
-#define _Py_CAST(type, expr) ((type)(expr))
-#endif
-
-// Static inline functions should use _Py_NULL rather than using directly NULL
-// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
-// _Py_NULL is defined as nullptr.
-#ifndef _Py_NULL
-#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
- (defined(__cplusplus) && __cplusplus >= 201103)
-#define _Py_NULL nullptr
-#else
-#define _Py_NULL NULL
-#endif
-#endif
-
-// Python 3.10.0a3
-#if PY_VERSION_HEX < 0x030A00A3
-
-// bpo-42262 added Py_XNewRef()
-#if !defined(Py_XNewRef)
-[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
- Py_XINCREF(obj);
- return obj;
-}
-#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
-#endif
-
-// bpo-42262 added Py_NewRef()
-#if !defined(Py_NewRef)
-[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
- Py_INCREF(obj);
- return obj;
-}
-#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
-#endif
-
-#endif // Python 3.10.0a3
-
-// Python 3.9.0b1
-#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
-
-// bpo-40429 added PyThreadState_GetFrame()
-PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
- assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
-}
-
-// bpo-40421 added PyFrame_GetBack()
-PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
-}
-
-// bpo-40421 added PyFrame_GetCode()
-PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
- return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
-}
-
-#endif // Python 3.9.0b1
-
-MlirLocation tracebackToLocation(MlirContext ctx) {
- size_t framesLimit =
- PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
- // Use a thread_local here to avoid requiring a large amount of space.
- thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
- frames;
- size_t count = 0;
-
- nb::gil_scoped_acquire acquire;
- PyThreadState *tstate = PyThreadState_GET();
- PyFrameObject *next;
- PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
- // In the increment expression:
- // 1. get the next prev frame;
- // 2. decrement the ref count on the current frame (in order that it can get
- // gc'd, along with any objects in its closure and etc);
- // 3. set current = next.
- for (; pyFrame != nullptr && count < framesLimit;
- next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
- PyCodeObject *code = PyFrame_GetCode(pyFrame);
- auto fileNameStr =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
- llvm::StringRef fileName(fileNameStr);
- if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
- continue;
-
- // co_qualname and PyCode_Addr2Location added in py3.11
-#if PY_VERSION_HEX < 0x030B00F0
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
- llvm::StringRef funcName(name);
- int startLine = PyFrame_GetLineNumber(pyFrame);
- MlirLocation loc =
- mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
-#else
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
- llvm::StringRef funcName(name);
- int startLine, startCol, endLine, endCol;
- int lasti = PyFrame_GetLasti(pyFrame);
- if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
- &endCol)) {
- throw nb::python_error();
- }
- MlirLocation loc = mlirLocationFileLineColRangeGet(
- ctx, wrap(fileName), startLine, startCol, endLine, endCol);
-#endif
-
- frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
- ++count;
- }
- // When the loop breaks (after the last iter), current frame (if non-null)
- // is leaked without this.
- Py_XDECREF(pyFrame);
-
- if (count == 0)
- return mlirLocationUnknownGet(ctx);
-
- MlirLocation callee = frames[0];
- assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
- if (count == 1)
- return callee;
-
- MlirLocation caller = frames[count - 1];
- assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
- for (int i = count - 2; i >= 1; i--)
- caller = mlirLocationCallSiteGet(frames[i], caller);
-
- return mlirLocationCallSiteGet(callee, caller);
-}
-
-PyLocation
-maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
- if (location.has_value())
- return location.value();
- if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
- return DefaultingPyLocation::resolve();
-
- PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
- MlirLocation mlirLoc = tracebackToLocation(ctx.get());
- PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
- return {ref, mlirLoc};
-}
-
-} // namespace
-
-//------------------------------------------------------------------------------
-// Populates the core exports of the 'ir' submodule.
-//------------------------------------------------------------------------------
-
-void mlir::python::populateIRCore(nb::module_ &m) {
- // disable leak warnings which tend to be false positives.
- nb::set_leak_warnings(false);
- //----------------------------------------------------------------------------
- // Enums.
- //----------------------------------------------------------------------------
- nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
- .value("ERROR", MlirDiagnosticError)
- .value("WARNING", MlirDiagnosticWarning)
- .value("NOTE", MlirDiagnosticNote)
- .value("REMARK", MlirDiagnosticRemark);
-
- nb::enum_<MlirWalkOrder>(m, "WalkOrder")
- .value("PRE_ORDER", MlirWalkPreOrder)
- .value("POST_ORDER", MlirWalkPostOrder);
-
- nb::enum_<MlirWalkResult>(m, "WalkResult")
- .value("ADVANCE", MlirWalkResultAdvance)
- .value("INTERRUPT", MlirWalkResultInterrupt)
- .value("SKIP", MlirWalkResultSkip);
-
- //----------------------------------------------------------------------------
- // Mapping of Diagnostics.
- //----------------------------------------------------------------------------
- nb::class_<PyDiagnostic>(m, "Diagnostic")
- .def_prop_ro("severity", &PyDiagnostic::getSeverity,
- "Returns the severity of the diagnostic.")
- .def_prop_ro("location", &PyDiagnostic::getLocation,
- "Returns the location associated with the diagnostic.")
- .def_prop_ro("message", &PyDiagnostic::getMessage,
- "Returns the message text of the diagnostic.")
- .def_prop_ro("notes", &PyDiagnostic::getNotes,
- "Returns a tuple of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic &self) -> nb::str {
- if (!self.isValid())
- return nb::str("<Invalid Diagnostic>");
- return self.getMessage();
- },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
- .def(
- "__init__",
- [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
- new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
- },
- "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
- .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
- "The severity level of the diagnostic.")
- .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
- "The location associated with the diagnostic.")
- .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
- "The message text of the diagnostic.")
- .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
- "List of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
- .def("detach", &PyDiagnosticHandler::detach,
- "Detaches the diagnostic handler from the context.")
- .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
- "Returns True if the handler is attached to a context.")
- .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
- "Returns True if an error was encountered during diagnostic "
- "handling.")
- .def("__enter__", &PyDiagnosticHandler::contextEnter,
- "Enters the diagnostic handler as a context manager.")
- .def("__exit__", &PyDiagnosticHandler::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the diagnostic handler context manager.");
-
- // Expose DefaultThreadPool to python
- nb::class_<PyThreadPool>(m, "ThreadPool")
- .def(
- "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
- "Creates a new thread pool with default concurrency.")
- .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
- "Returns the maximum number of threads in the pool.")
- .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
- "Returns the raw pointer to the LLVM thread pool as a string.");
-
- nb::class_<PyMlirContext>(m, "Context")
- .def(
- "__init__",
- [](PyMlirContext &self) {
- MlirContext context = mlirContextCreateWithThreading(false);
- new (&self) PyMlirContext(context);
- },
- R"(
- Creates a new MLIR context.
-
- The context is the top-level container for all MLIR objects. It owns the storage
- for types, attributes, locations, and other core IR objects. A context can be
- configured to allow or disallow unregistered dialects and can have dialects
- loaded on-demand.)")
- .def_static("_get_live_count", &PyMlirContext::getLiveCount,
- "Gets the number of live Context objects.")
- .def(
- "_get_context_again",
- [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
- PyMlirContextRef ref = PyMlirContext::forContext(self.get());
- return ref.releaseObject();
- },
- "Gets another reference to the same context.")
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
- "Gets the number of live modules owned by this context.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
- "Gets a capsule wrapping the MlirContext.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyMlirContext::createFromCapsule,
- "Creates a Context from a capsule wrapping MlirContext.")
- .def("__enter__", &PyMlirContext::contextEnter,
- "Enters the context as a context manager.")
- .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/)
- -> std::optional<nb::typed<nb::object, PyMlirContext>> {
- auto *context = PyThreadContextEntry::getDefaultContext();
- if (!context)
- return {};
- return nb::cast(context);
- },
- nb::sig("def current(/) -> Context | None"),
- "Gets the Context bound to the current thread or returns None if no "
- "context is set.")
- .def_prop_ro(
- "dialects",
- [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Gets a container for accessing dialects by name.")
- .def_prop_ro(
- "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Alias for `dialects`.")
- .def(
- "get_dialect_descriptor",
- [=](PyMlirContext &self, std::string &name) {
- MlirDialect dialect = mlirContextGetOrLoadDialect(
- self.get(), {name.data(), name.size()});
- if (mlirDialectIsNull(dialect)) {
- throw nb::value_error(
- (Twine("Dialect '") + name + "' not found").str().c_str());
- }
- return PyDialectDescriptor(self.getRef(), dialect);
- },
- nb::arg("dialect_name"),
- "Gets or loads a dialect by name, returning its descriptor object.")
- .def_prop_rw(
- "allow_unregistered_dialects",
- [](PyMlirContext &self) -> bool {
- return mlirContextGetAllowUnregisteredDialects(self.get());
- },
- [](PyMlirContext &self, bool value) {
- mlirContextSetAllowUnregisteredDialects(self.get(), value);
- },
- "Controls whether unregistered dialects are allowed in this context.")
- .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
- nb::arg("callback"),
- "Attaches a diagnostic handler that will receive callbacks.")
- .def(
- "enable_multithreading",
- [](PyMlirContext &self, bool enable) {
- mlirContextEnableMultithreading(self.get(), enable);
- },
- nb::arg("enable"),
- R"(
- Enables or disables multi-threading support in the context.
-
- Args:
- enable: Whether to enable (True) or disable (False) multi-threading.
- )")
- .def(
- "set_thread_pool",
- [](PyMlirContext &self, PyThreadPool &pool) {
- // we should disable multi-threading first before setting
- // new thread pool otherwise the assert in
- // MLIRContext::setThreadPool will be raised.
- mlirContextEnableMultithreading(self.get(), false);
- mlirContextSetThreadPool(self.get(), pool.get());
- },
- R"(
- Sets a custom thread pool for the context to use.
-
- Args:
- pool: A ThreadPool object to use for parallel operations.
-
- Note:
- Multi-threading is automatically disabled before setting the thread pool.)")
- .def(
- "get_num_threads",
- [](PyMlirContext &self) {
- return mlirContextGetNumThreads(self.get());
- },
- "Gets the number of threads in the context's thread pool.")
- .def(
- "_mlir_thread_pool_ptr",
- [](PyMlirContext &self) {
- MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
- std::stringstream ss;
- ss << pool.ptr;
- return ss.str();
- },
- "Gets the raw pointer to the LLVM thread pool as a string.")
- .def(
- "is_registered_operation",
- [](PyMlirContext &self, std::string &name) {
- return mlirContextIsRegisteredOperation(
- self.get(), MlirStringRef{name.data(), name.size()});
- },
- nb::arg("operation_name"),
- R"(
- Checks whether an operation with the given name is registered.
-
- Args:
- operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
-
- Returns:
- True if the operation is registered, False otherwise.)")
- .def(
- "append_dialect_registry",
- [](PyMlirContext &self, PyDialectRegistry &registry) {
- mlirContextAppendDialectRegistry(self.get(), registry);
- },
- nb::arg("registry"),
- R"(
- Appends the contents of a dialect registry to the context.
-
- Args:
- registry: A DialectRegistry containing dialects to append.)")
- .def_prop_rw("emit_error_diagnostics",
- &PyMlirContext::getEmitErrorDiagnostics,
- &PyMlirContext::setEmitErrorDiagnostics,
- R"(
- Controls whether error diagnostics are emitted to diagnostic handlers.
-
- By default, error diagnostics are captured and reported through MLIRError exceptions.)")
- .def(
- "load_all_available_dialects",
- [](PyMlirContext &self) {
- mlirContextLoadAllAvailableDialects(self.get());
- },
- R"(
- Loads all dialects available in the registry into the context.
-
- This eagerly loads all dialects that have been registered, making them
- immediately available for use.)");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectDescriptor
- //----------------------------------------------------------------------------
- nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
- .def_prop_ro(
- "namespace",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- return nb::str(ns.data, ns.length);
- },
- "Returns the namespace of the dialect.")
- .def(
- "__repr__",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- std::string repr("<DialectDescriptor ");
- repr.append(ns.data, ns.length);
- repr.append(">");
- return repr;
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect descriptor.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialects
- //----------------------------------------------------------------------------
- nb::class_<PyDialects>(m, "Dialects")
- .def(
- "__getitem__",
- [=](PyDialects &self, std::string keyName) {
- MlirDialect dialect =
- self.getDialectForKey(keyName, /*attrError=*/false);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(keyName, std::move(descriptor));
- },
- "Gets a dialect by name using subscript notation.")
- .def(
- "__getattr__",
- [=](PyDialects &self, std::string attrName) {
- MlirDialect dialect =
- self.getDialectForKey(attrName, /*attrError=*/true);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(attrName, std::move(descriptor));
- },
- "Gets a dialect by name using attribute notation.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialect
- //----------------------------------------------------------------------------
- nb::class_<PyDialect>(m, "Dialect")
- .def(nb::init<nb::object>(), nb::arg("descriptor"),
- "Creates a Dialect from a DialectDescriptor.")
- .def_prop_ro(
- "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
- "Returns the DialectDescriptor for this dialect.")
- .def(
- "__repr__",
- [](const nb::object &self) {
- auto clazz = self.attr("__class__");
- return nb::str("<Dialect ") +
- self.attr("descriptor").attr("namespace") +
- nb::str(" (class ") + clazz.attr("__module__") +
- nb::str(".") + clazz.attr("__name__") + nb::str(")>");
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectRegistry
- //----------------------------------------------------------------------------
- nb::class_<PyDialectRegistry>(m, "DialectRegistry")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
- "Gets a capsule wrapping the MlirDialectRegistry.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyDialectRegistry::createFromCapsule,
- "Creates a DialectRegistry from a capsule wrapping "
- "`MlirDialectRegistry`.")
- .def(nb::init<>(), "Creates a new empty dialect registry.");
-
- //----------------------------------------------------------------------------
- // Mapping of Location
- //----------------------------------------------------------------------------
- nb::class_<PyLocation>(m, "Location")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
- "Gets a capsule wrapping the MlirLocation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
- "Creates a Location from a capsule wrapping MlirLocation.")
- .def("__enter__", &PyLocation::contextEnter,
- "Enters the location as a context manager.")
- .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the location context manager.")
- .def(
- "__eq__",
- [](PyLocation &self, PyLocation &other) -> bool {
- return mlirLocationEqual(self, other);
- },
- "Compares two locations for equality.")
- .def(
- "__eq__", [](PyLocation &self, nb::object other) { return false; },
- "Compares location with non-location object (always returns False).")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) -> std::optional<PyLocation *> {
- auto *loc = PyThreadContextEntry::getDefaultLocation();
- if (!loc)
- return std::nullopt;
- return loc;
- },
- // clang-format off
- nb::sig("def current(/) -> Location | None"),
- // clang-format on
- "Gets the Location bound to the current thread or raises ValueError.")
- .def_static(
- "unknown",
- [](DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationUnknownGet(context->get()));
- },
- nb::arg("context") = nb::none(),
- "Gets a Location representing an unknown location.")
- .def_static(
- "callsite",
- [](PyLocation callee, const std::vector<PyLocation> &frames,
- DefaultingPyMlirContext context) {
- if (frames.empty())
- throw nb::value_error("No caller frames provided.");
- MlirLocation caller = frames.back().get();
- for (const PyLocation &frame :
- llvm::reverse(llvm::ArrayRef(frames).drop_back()))
- caller = mlirLocationCallSiteGet(frame.get(), caller);
- return PyLocation(context->getRef(),
- mlirLocationCallSiteGet(callee.get(), caller));
- },
- nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
- "Gets a Location representing a caller and callsite.")
- .def("is_a_callsite", mlirLocationIsACallSite,
- "Returns True if this location is a CallSiteLoc.")
- .def_prop_ro(
- "callee",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCallee(self));
- },
- "Gets the callee location from a CallSiteLoc.")
- .def_prop_ro(
- "caller",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCaller(self));
- },
- "Gets the caller location from a CallSiteLoc.")
- .def_static(
- "file",
- [](std::string filename, int line, int col,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationFileLineColGet(
- context->get(), toMlirStringRef(filename), line, col));
- },
- nb::arg("filename"), nb::arg("line"), nb::arg("col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column.")
- .def_static(
- "file",
- [](std::string filename, int startLine, int startCol, int endLine,
- int endCol, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFileLineColRangeGet(
- context->get(), toMlirStringRef(filename),
- startLine, startCol, endLine, endCol));
- },
- nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
- nb::arg("end_line"), nb::arg("end_col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column range.")
- .def("is_a_file", mlirLocationIsAFileLineColRange,
- "Returns True if this location is a FileLineColLoc.")
- .def_prop_ro(
- "filename",
- [](MlirLocation loc) {
- return mlirIdentifierStr(
- mlirLocationFileLineColRangeGetFilename(loc));
- },
- "Gets the filename from a FileLineColLoc.")
- .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
- "Gets the start line number from a `FileLineColLoc`.")
- .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
- "Gets the start column number from a `FileLineColLoc`.")
- .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
- "Gets the end line number from a `FileLineColLoc`.")
- .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
- "Gets the end column number from a `FileLineColLoc`.")
- .def_static(
- "fused",
- [](const std::vector<PyLocation> &pyLocations,
- std::optional<PyAttribute> metadata,
- DefaultingPyMlirContext context) {
- llvm::SmallVector<MlirLocation, 4> locations;
- locations.reserve(pyLocations.size());
- for (auto &pyLocation : pyLocations)
- locations.push_back(pyLocation.get());
- MlirLocation location = mlirLocationFusedGet(
- context->get(), locations.size(), locations.data(),
- metadata ? metadata->get() : MlirAttribute{0});
- return PyLocation(context->getRef(), location);
- },
- nb::arg("locations"), nb::arg("metadata") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a fused location with optional "
- "metadata.")
- .def("is_a_fused", mlirLocationIsAFused,
- "Returns True if this location is a `FusedLoc`.")
- .def_prop_ro(
- "locations",
- [](PyLocation &self) {
- unsigned numLocations = mlirLocationFusedGetNumLocations(self);
- std::vector<MlirLocation> locations(numLocations);
- if (numLocations)
- mlirLocationFusedGetLocations(self, locations.data());
- std::vector<PyLocation> pyLocations{};
- pyLocations.reserve(numLocations);
- for (unsigned i = 0; i < numLocations; ++i)
- pyLocations.emplace_back(self.getContext(), locations[i]);
- return pyLocations;
- },
- "Gets the list of locations from a `FusedLoc`.")
- .def_static(
- "name",
- [](std::string name, std::optional<PyLocation> childLoc,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationNameGet(
- context->get(), toMlirStringRef(name),
- childLoc ? childLoc->get()
- : mlirLocationUnknownGet(context->get())));
- },
- nb::arg("name"), nb::arg("childLoc") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a named location with optional child "
- "location.")
- .def("is_a_name", mlirLocationIsAName,
- "Returns True if this location is a `NameLoc`.")
- .def_prop_ro(
- "name_str",
- [](MlirLocation loc) {
- return mlirIdentifierStr(mlirLocationNameGetName(loc));
- },
- "Gets the name string from a `NameLoc`.")
- .def_prop_ro(
- "child_loc",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationNameGetChildLoc(self));
- },
- "Gets the child location from a `NameLoc`.")
- .def_static(
- "from_attr",
- [](PyAttribute &attribute, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFromAttribute(attribute));
- },
- nb::arg("attribute"), nb::arg("context") = nb::none(),
- "Gets a Location from a `LocationAttr`.")
- .def_prop_ro(
- "context",
- [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Location`.")
- .def_prop_ro(
- "attr",
- [](PyLocation &self) {
- return PyAttribute(self.getContext(),
- mlirLocationGetAttribute(self));
- },
- "Get the underlying `LocationAttr`.")
- .def(
- "emit_error",
- [](PyLocation &self, std::string message) {
- mlirEmitError(self, message.c_str());
- },
- nb::arg("message"),
- R"(
- Emits an error diagnostic at this location.
-
- Args:
- message: The error message to emit.)")
- .def(
- "__repr__",
- [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly representation of the location.");
-
- //----------------------------------------------------------------------------
- // Mapping of Module
- //----------------------------------------------------------------------------
- nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
- "Gets a capsule wrapping the MlirModule.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
- R"(
- Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
-
- This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
- prevent double-frees (of the underlying `mlir::Module`).)")
- .def("_clear_mlir_module", &PyModule::clearMlirModule,
- R"(
- Clears the internal MLIR module reference.
-
- This is used internally to prevent double-free when ownership is transferred
- via the C API capsule mechanism. Not intended for normal use.)")
- .def_static(
- "parse",
- [](const std::string &moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parse",
- [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parseFile",
- [](const std::string &path, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParseFromFile(
- context->get(), toMlirStringRef(path));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("path"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "create",
- [](const std::optional<PyLocation> &loc)
- -> nb::typed<nb::object, PyModule> {
- PyLocation pyLoc = maybeGetTracebackLocation(loc);
- MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("loc") = nb::none(), "Creates an empty module.")
- .def_prop_ro(
- "context",
- [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that created the `Module`.")
- .def_prop_ro(
- "operation",
- [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
- return PyOperation::forOperation(self.getContext(),
- mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject())
- .releaseObject();
- },
- "Accesses the module as an operation.")
- .def_prop_ro(
- "body",
- [](PyModule &self) {
- PyOperationRef moduleOp = PyOperation::forOperation(
- self.getContext(), mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject());
- PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
- return returnBlock;
- },
- "Return the block for this module.")
- .def(
- "dump",
- [](PyModule &self) {
- mlirOperationDump(mlirModuleGetOperation(self.get()));
- },
- kDumpDocstring)
- .def(
- "__str__",
- [](const nb::object &self) {
- // Defer to the operation's __str__.
- return self.attr("operation").attr("__str__")();
- },
- nb::sig("def __str__(self) -> str"),
- R"(
- Gets the assembly form of the operation with default options.
-
- If more advanced control over the assembly formatting or I/O options is needed,
- use the dedicated print or get_asm method, which supports keyword arguments to
- customize behavior.
- )")
- .def(
- "__eq__",
- [](PyModule &self, PyModule &other) {
- return mlirModuleEqual(self.get(), other.get());
- },
- "other"_a, "Compares two modules for equality.")
- .def(
- "__hash__",
- [](PyModule &self) { return mlirModuleHashValue(self.get()); },
- "Returns the hash value of the module.");
-
- //----------------------------------------------------------------------------
- // Mapping of Operation.
- //----------------------------------------------------------------------------
- nb::class_<PyOperationBase>(m, "_OperationBase")
- .def_prop_ro(
- MLIR_PYTHON_CAPI_PTR_ATTR,
- [](PyOperationBase &self) {
- return self.getOperation().getCapsule();
- },
- "Gets a capsule wrapping the `MlirOperation`.")
- .def(
- "__eq__",
- [](PyOperationBase &self, PyOperationBase &other) {
- return mlirOperationEqual(self.getOperation().get(),
- other.getOperation().get());
- },
- "Compares two operations for equality.")
- .def(
- "__eq__",
- [](PyOperationBase &self, nb::object other) { return false; },
- "Compares operation with non-operation object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyOperationBase &self) {
- return mlirOperationHashValue(self.getOperation().get());
- },
- "Returns the hash value of the operation.")
- .def_prop_ro(
- "attributes",
- [](PyOperationBase &self) {
- return PyOpAttributeMap(self.getOperation().getRef());
- },
- "Returns a dictionary-like map of operation attributes.")
- .def_prop_ro(
- "context",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
- PyOperation &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- return concreteOperation.getContext().getObject();
- },
- "Context that owns the operation.")
- .def_prop_ro(
- "name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation = concreteOperation.get();
- return mlirIdentifierStr(mlirOperationGetName(operation));
- },
- "Returns the fully qualified name of the operation.")
- .def_prop_ro(
- "operands",
- [](PyOperationBase &self) {
- return PyOpOperandList(self.getOperation().getRef());
- },
- "Returns the list of operation operands.")
- .def_prop_ro(
- "regions",
- [](PyOperationBase &self) {
- return PyRegionList(self.getOperation().getRef());
- },
- "Returns the list of operation regions.")
- .def_prop_ro(
- "results",
- [](PyOperationBase &self) {
- return PyOpResultList(self.getOperation().getRef());
- },
- "Returns the list of Operation results.")
- .def_prop_ro(
- "result",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
- auto &operation = self.getOperation();
- return PyOpResult(operation.getRef(), getUniqueResult(operation))
- .maybeDownCast();
- },
- "Shortcut to get an op result if it has only one (throws an error "
- "otherwise).")
- .def_prop_rw(
- "location",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- return PyLocation(operation.getContext(),
- mlirOperationGetLocation(operation.get()));
- },
- [](PyOperationBase &self, const PyLocation &location) {
- PyOperation &operation = self.getOperation();
- mlirOperationSetLocation(operation.get(), location.get());
- },
- nb::for_getter("Returns the source location the operation was "
- "defined or derived from."),
- nb::for_setter("Sets the source location the operation was defined "
- "or derived from."))
- .def_prop_ro(
- "parent",
- [](PyOperationBase &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto parent = self.getOperation().getParentOperation();
- if (parent)
- return parent->getObject();
- return {};
- },
- "Returns the parent operation, or `None` if at top level.")
- .def(
- "__str__",
- [](PyOperationBase &self) {
- return self.getAsm(/*binary=*/false,
- /*largeElementsLimit=*/std::nullopt,
- /*largeResourceLimit=*/std::nullopt,
- /*enableDebugInfo=*/false,
- /*prettyDebugInfo=*/false,
- /*printGenericOpForm=*/false,
- /*useLocalScope=*/false,
- /*useNameLocAsPrefix=*/false,
- /*assumeVerified=*/false,
- /*skipRegions=*/false);
- },
- nb::sig("def __str__(self) -> str"),
- "Returns the assembly form of the operation.")
- .def("print",
- nb::overload_cast<PyAsmState &, nb::object, bool>(
- &PyOperationBase::print),
- nb::arg("state"), nb::arg("file") = nb::none(),
- nb::arg("binary") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- state: `AsmState` capturing the operation numbering and flags.
- file: Optional file like object to write to. Defaults to sys.stdout.
- binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
- .def("print",
- nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
- bool, bool, bool, bool, bool, bool, nb::object,
- bool, bool>(&PyOperationBase::print),
- // Careful: Lots of arguments must match up with print method.
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
- nb::arg("binary") = false, nb::arg("skip_regions") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- large_elements_limit: Whether to elide elements attributes above this
- number of elements. Defaults to None (no limit).
- large_resource_limit: Whether to elide resource attributes above this
- number of characters. Defaults to None (no limit). If large_elements_limit
- is set and this is None, the behavior will be to use large_elements_limit
- as large_resource_limit.
- enable_debug_info: Whether to print debug/location information. Defaults
- to False.
- pretty_debug_info: Whether to format debug information for easier reading
- by a human (warning: the result is unparseable). Defaults to False.
- print_generic_op_form: Whether to print the generic assembly forms of all
- ops. Defaults to False.
- use_local_scope: Whether to print in a way that is more optimized for
- multi-threaded access but may not be consistent with how the overall
- module prints.
- use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
- prefixes for the SSA identifiers. Defaults to False.
- assume_verified: By default, if not printing generic form, the verifier
- will be run and if it fails, generic form will be printed with a comment
- about failed verification. While a reasonable default for interactive use,
- for systematic use, it is often better for the caller to verify explicitly
- and report failures in a more robust fashion. Set this to True if doing this
- in order to avoid running a redundant verification. If the IR is actually
- invalid, behavior is undefined.
- file: The file like object to write to. Defaults to sys.stdout.
- binary: Whether to write bytes (True) or str (False). Defaults to False.
- skip_regions: Whether to skip printing regions. Defaults to False.)")
- .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
- nb::arg("desired_version") = nb::none(),
- R"(
- Write the bytecode form of the operation to a file like object.
-
- Args:
- file: The file like object to write to.
- desired_version: Optional version of bytecode to emit.
- Returns:
- The bytecode writer status.)")
- .def("get_asm", &PyOperationBase::getAsm,
- // Careful: Lots of arguments must match up with get_asm method.
- nb::arg("binary") = false,
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
- R"(
- Gets the assembly form of the operation with all options available.
-
- Args:
- binary: Whether to return a bytes (True) or str (False) object. Defaults to
- False.
- ... others ...: See the print() method for common keyword arguments for
- configuring the printout.
- Returns:
- Either a bytes or str object, depending on the setting of the `binary`
- argument.)")
- .def("verify", &PyOperationBase::verify,
- "Verify the operation. Raises MLIRError if verification fails, and "
- "returns true otherwise.")
- .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
- "Puts self immediately after the other operation in its parent "
- "block.")
- .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
- "Puts self immediately before the other operation in its parent "
- "block.")
- .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
- nb::arg("other"),
- R"(
- Checks if this operation is before another in the same block.
-
- Args:
- other: Another operation in the same parent block.
-
- Returns:
- True if this operation is before `other` in the operation list of the parent block.)")
- .def(
- "clone",
- [](PyOperationBase &self,
- const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
- return self.getOperation().clone(ip);
- },
- nb::arg("ip") = nb::none(),
- R"(
- Creates a deep copy of the operation.
-
- Args:
- ip: Optional insertion point where the cloned operation should be inserted.
- If None, the current insertion point is used. If False, the operation
- remains detached.
-
- Returns:
- A new Operation that is a clone of this operation.)")
- .def(
- "detach_from_parent",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- if (!operation.isAttached())
- throw nb::value_error("Detached operation has no parent.");
-
- operation.detachFromParent();
- return operation.createOpView();
- },
- "Detaches the operation from its parent block.")
- .def_prop_ro(
- "attached",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- return operation.isAttached();
- },
- "Reports if the operation is attached to its parent block.")
- .def(
- "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
- R"(
- Erases the operation and frees its memory.
-
- Note:
- After erasing, any Python references to the operation become invalid.)")
- .def("walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
- // clang-format on
- R"(
- Walks the operation tree with a callback function.
-
- Args:
- callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
-
- nb::class_<PyOperation, PyOperationBase>(m, "Operation")
- .def_static(
- "create",
- [](std::string_view name,
- std::optional<std::vector<PyType *>> results,
- std::optional<std::vector<PyValue *>> operands,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors, int regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp,
- bool inferType) -> nb::typed<nb::object, PyOperation> {
- // Unpack/validate operands.
- llvm::SmallVector<MlirValue, 4> mlirOperands;
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue *operand : *operands) {
- if (!operand)
- throw nb::value_error("operand value cannot be None");
- mlirOperands.push_back(operand->get());
- }
- }
-
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOperation::create(name, results, mlirOperands, attributes,
- successors, regions, pyLoc, maybeIp,
- inferType);
- },
- nb::arg("name"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- nb::arg("infer_type") = false,
- R"(
- Creates a new operation.
-
- Args:
- name: Operation name (e.g. `dialect.operation`).
- results: Optional sequence of Type representing op result types.
- operands: Optional operands of the operation.
- attributes: Optional Dict of {str: Attribute}.
- successors: Optional List of Block for the operation's successors.
- regions: Number of regions to create (default = 0).
- location: Optional Location object (defaults to resolve from context manager).
- ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
- infer_type: Whether to infer result types (default = False).
- Returns:
- A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
- .def_static(
- "parse",
- [](const std::string &sourceStr, const std::string &sourceName,
- DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyOpView> {
- return PyOperation::parse(context->getRef(), sourceStr, sourceName)
- ->createOpView();
- },
- nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
- nb::arg("context") = nb::none(),
- "Parses an operation. Supports both text assembly format and binary "
- "bytecode format.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
- "Gets a capsule wrapping the MlirOperation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyOperation::createFromCapsule,
- "Creates an Operation from a capsule wrapping MlirOperation.")
- .def_prop_ro(
- "operation",
- [](nb::object self) -> nb::typed<nb::object, PyOperation> {
- return self;
- },
- "Returns self (the operation).")
- .def_prop_ro(
- "opview",
- [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
- return self.createOpView();
- },
- R"(
- Returns an OpView of this operation.
-
- Note:
- If the operation has a registered and loaded dialect then this OpView will
- be concrete wrapper class.)")
- .def_prop_ro("block", &PyOperation::getBlock,
- "Returns the block containing this operation.")
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def("_set_invalid", &PyOperation::setInvalid,
- "Invalidate the operation.");
-
- auto opViewClass =
- nb::class_<PyOpView, PyOperationBase>(m, "OpView")
- .def(nb::init<nb::typed<nb::object, PyOperation>>(),
- nb::arg("operation"))
- .def(
- "__init__",
- [](PyOpView *self, std::string_view name,
- std::tuple<int, bool> opRegionSpec,
- nb::object operandSegmentSpecObj,
- nb::object resultSegmentSpecObj,
- std::optional<nb::list> resultTypeList, nb::list operandList,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp) {
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- new (self) PyOpView(PyOpView::buildGeneric(
- name, opRegionSpec, operandSegmentSpecObj,
- resultSegmentSpecObj, resultTypeList, operandList,
- attributes, successors, regions, pyLoc, maybeIp));
- },
- nb::arg("name"), nb::arg("opRegionSpec"),
- nb::arg("operandSegmentSpecObj") = nb::none(),
- nb::arg("resultSegmentSpecObj") = nb::none(),
- nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
- nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(),
- nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
- nb::arg("ip") = nb::none())
- .def_prop_ro(
- "operation",
- [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
- return self.getOperationObject();
- })
- .def_prop_ro("opview",
- [](nb::object self) -> nb::typed<nb::object, PyOpView> {
- return self;
- })
- .def(
- "__str__",
- [](PyOpView &self) { return nb::str(self.getOperationObject()); })
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def(
- "_set_invalid",
- [](PyOpView &self) { self.getOperation().setInvalid(); },
- "Invalidate the operation.");
- opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
- opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
- opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
- // It is faster to pass the operation_name, ods_regions, and
- // ods_operand_segments/ods_result_segments as arguments to the constructor,
- // rather than to access them as attributes.
- opViewClass.attr("build_generic") = classmethod(
- [](nb::handle cls, std::optional<nb::list> resultTypeList,
- nb::list operandList, std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, std::optional<PyLocation> location,
- const nb::object &maybeIp) {
- std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- std::tuple<int, bool> opRegionSpec =
- nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
- nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
- nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
- resultSegmentSpec, resultTypeList,
- operandList, attributes, successors,
- regions, pyLoc, maybeIp);
- },
- nb::arg("cls"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- "Builds a specific, generated OpView based on class level attributes.");
- opViewClass.attr("parse") = classmethod(
- [](const nb::object &cls, const std::string &sourceStr,
- const std::string &sourceName,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
- PyOperationRef parsed =
- PyOperation::parse(context->getRef(), sourceStr, sourceName);
-
- // Check if the expected operation was parsed, and cast to to the
- // appropriate `OpView` subclass if successful.
- // NOTE: This accesses attributes that have been automatically added to
- // `OpView` subclasses, and is not intended to be used on `OpView`
- // directly.
- std::string clsOpName =
- nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- MlirStringRef identifier =
- mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
- std::string_view parsedOpName(identifier.data, identifier.length);
- if (clsOpName != parsedOpName)
- throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
- parsedOpName + "'");
- return PyOpView::constructDerived(cls, parsed.getObject());
- },
- nb::arg("cls"), nb::arg("source"), nb::kw_only(),
- nb::arg("source_name") = "", nb::arg("context") = nb::none(),
- "Parses a specific, generated OpView based on class level attributes.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyRegion.
- //----------------------------------------------------------------------------
- nb::class_<PyRegion>(m, "Region")
- .def_prop_ro(
- "blocks",
- [](PyRegion &self) {
- return PyBlockList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of blocks.")
- .def_prop_ro(
- "owner",
- [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation owning this region.")
- .def(
- "__iter__",
- [](PyRegion &self) {
- self.checkValid();
- MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
- return PyBlockIterator(self.getParentOperation(), firstBlock);
- },
- "Iterates over blocks in the region.")
- .def(
- "__eq__",
- [](PyRegion &self, PyRegion &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two regions for pointer equality.")
- .def(
- "__eq__", [](PyRegion &self, nb::object &other) { return false; },
- "Compares region with non-region object (always returns False).");
-
- //----------------------------------------------------------------------------
- // Mapping of PyBlock.
- //----------------------------------------------------------------------------
- nb::class_<PyBlock>(m, "Block")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
- "Gets a capsule wrapping the MlirBlock.")
- .def_prop_ro(
- "owner",
- [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the owning operation of this block.")
- .def_prop_ro(
- "region",
- [](PyBlock &self) {
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- return PyRegion(self.getParentOperation(), region);
- },
- "Returns the owning region of this block.")
- .def_prop_ro(
- "arguments",
- [](PyBlock &self) {
- return PyBlockArgumentList(self.getParentOperation(), self.get());
- },
- "Returns a list of block arguments.")
- .def(
- "add_argument",
- [](PyBlock &self, const PyType &type, const PyLocation &loc) {
- return PyBlockArgument(self.getParentOperation(),
- mlirBlockAddArgument(self.get(), type, loc));
- },
- "type"_a, "loc"_a,
- R"(
- Appends an argument of the specified type to the block.
-
- Args:
- type: The type of the argument to add.
- loc: The source location for the argument.
-
- Returns:
- The newly added block argument.)")
- .def(
- "erase_argument",
- [](PyBlock &self, unsigned index) {
- return mlirBlockEraseArgument(self.get(), index);
- },
- nb::arg("index"),
- R"(
- Erases the argument at the specified index.
-
- Args:
- index: The index of the argument to erase.)")
- .def_prop_ro(
- "operations",
- [](PyBlock &self) {
- return PyOperationList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of operations.")
- .def_static(
- "create_at_start",
- [](PyRegion &parent, const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- parent.checkValid();
- MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
- mlirRegionInsertOwnedBlock(parent, 0, block);
- return PyBlock(parent.getParentOperation(), block);
- },
- nb::arg("parent"), nb::arg("arg_types") = nb::list(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block at the beginning of the given "
- "region (with given argument types and locations).")
- .def(
- "append_to",
- [](PyBlock &self, PyRegion ®ion) {
- MlirBlock b = self.get();
- if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
- mlirBlockDetach(b);
- mlirRegionAppendOwnedBlock(region.get(), b);
- },
- nb::arg("region"),
- R"(
- Appends this block to a region.
-
- Transfers ownership if the block is currently owned by another region.
-
- Args:
- region: The region to append the block to.)")
- .def(
- "create_before",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block before this block "
- "(with given argument types and locations).")
- .def(
- "create_after",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block after this block "
- "(with given argument types and locations).")
- .def(
- "__iter__",
- [](PyBlock &self) {
- self.checkValid();
- MlirOperation firstOperation =
- mlirBlockGetFirstOperation(self.get());
- return PyOperationIterator(self.getParentOperation(),
- firstOperation);
- },
- "Iterates over operations in the block.")
- .def(
- "__eq__",
- [](PyBlock &self, PyBlock &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two blocks for pointer equality.")
- .def(
- "__eq__", [](PyBlock &self, nb::object &other) { return false; },
- "Compares block with non-block object (always returns False).")
- .def(
- "__hash__",
- [](PyBlock &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the block.")
- .def(
- "__str__",
- [](PyBlock &self) {
- self.checkValid();
- PyPrintAccumulator printAccum;
- mlirBlockPrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the block.")
- .def(
- "append",
- [](PyBlock &self, PyOperationBase &operation) {
- if (operation.getOperation().isAttached())
- operation.getOperation().detachFromParent();
-
- MlirOperation mlirOperation = operation.getOperation().get();
- mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
- operation.getOperation().setAttached(
- self.getParentOperation().getObject());
- },
- nb::arg("operation"),
- R"(
- Appends an operation to this block.
-
- If the operation is currently in another block, it will be moved.
-
- Args:
- operation: The operation to append to the block.)")
- .def_prop_ro(
- "successors",
- [](PyBlock &self) {
- return PyBlockSuccessors(self, self.getParentOperation());
- },
- "Returns the list of Block successors.")
- .def_prop_ro(
- "predecessors",
- [](PyBlock &self) {
- return PyBlockPredecessors(self, self.getParentOperation());
- },
- "Returns the list of Block predecessors.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyInsertionPoint.
- //----------------------------------------------------------------------------
-
- nb::class_<PyInsertionPoint>(m, "InsertionPoint")
- .def(nb::init<PyBlock &>(), nb::arg("block"),
- "Inserts after the last operation but still inside the block.")
- .def("__enter__", &PyInsertionPoint::contextEnter,
- "Enters the insertion point as a context manager.")
- .def("__exit__", &PyInsertionPoint::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the insertion point context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) {
- auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
- if (!ip)
- throw nb::value_error("No current InsertionPoint");
- return ip;
- },
- nb::sig("def current(/) -> InsertionPoint"),
- "Gets the InsertionPoint bound to the current thread or raises "
- "ValueError if none has been set.")
- .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
- "Inserts before a referenced operation.")
- .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
- nb::arg("block"),
- R"(
- Creates an insertion point at the beginning of a block.
-
- Args:
- block: The block at whose beginning operations should be inserted.
-
- Returns:
- An InsertionPoint at the block's beginning.)")
- .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
- nb::arg("block"),
- R"(
- Creates an insertion point before a block's terminator.
-
- Args:
- block: The block whose terminator to insert before.
-
- Returns:
- An InsertionPoint before the terminator.
-
- Raises:
- ValueError: If the block has no terminator.)")
- .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
- R"(
- Creates an insertion point immediately after an operation.
-
- Args:
- operation: The operation after which to insert.
-
- Returns:
- An InsertionPoint after the operation.)")
- .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
- R"(
- Inserts an operation at this insertion point.
-
- Args:
- operation: The operation to insert.)")
- .def_prop_ro(
- "block", [](PyInsertionPoint &self) { return self.getBlock(); },
- "Returns the block that this `InsertionPoint` points to.")
- .def_prop_ro(
- "ref_operation",
- [](PyInsertionPoint &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto refOperation = self.getRefOperation();
- if (refOperation)
- return refOperation->getObject();
- return {};
- },
- "The reference operation before which new operations are "
- "inserted, or None if the insertion point is at the end of "
- "the block.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyAttribute.
- //----------------------------------------------------------------------------
- nb::class_<PyAttribute>(m, "Attribute")
- // Delegate to the PyAttribute copy constructor, which will also lifetime
- // extend the backing context which owns the MlirAttribute.
- .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
- "Casts the passed attribute to the generic `Attribute`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
- "Gets a capsule wrapping the MlirAttribute.")
- .def_static(
- MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
- "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
- .def_static(
- "parse",
- [](const std::string &attrSpec, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyAttribute> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute attr = mlirAttributeParseGet(
- context->get(), toMlirStringRef(attrSpec));
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Unable to parse attribute", errors.take());
- return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- "Parses an attribute from an assembly form. Raises an `MLIRError` on "
- "failure.")
- .def_prop_ro(
- "context",
- [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Attribute`.")
- .def_prop_ro(
- "type",
- [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirAttributeGetType(self))
- .maybeDownCast();
- },
- "Returns the type of the `Attribute`.")
- .def(
- "get_named",
- [](PyAttribute &self, std::string name) {
- return PyNamedAttribute(self, std::move(name));
- },
- nb::keep_alive<0, 1>(),
- R"(
- Binds a name to the attribute, creating a `NamedAttribute`.
-
- Args:
- name: The name to bind to the `Attribute`.
-
- Returns:
- A `NamedAttribute` with the given name and this attribute.)")
- .def(
- "__eq__",
- [](PyAttribute &self, PyAttribute &other) { return self == other; },
- "Compares two attributes for equality.")
- .def(
- "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
- "Compares attribute with non-attribute object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyAttribute &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the attribute.")
- .def(
- "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
- kDumpDocstring)
- .def(
- "__str__",
- [](PyAttribute &self) {
- PyPrintAccumulator printAccum;
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the Attribute.")
- .def(
- "__repr__",
- [](PyAttribute &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, attribute values are generally considered useful and
- // are printed. This may need to be re-evaluated if debug dumps end
- // up being excessive.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Attribute(");
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the attribute.")
- .def_prop_ro(
- "typeid",
- [](PyAttribute &self) {
- MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
- assert(!mlirTypeIDIsNull(mlirTypeID) &&
- "mlirTypeID was expected to be non-null.");
- return PyTypeID(mlirTypeID);
- },
- "Returns the `TypeID` of the attribute.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
- return self.maybeDownCast();
- },
- "Downcasts the attribute to a more specific attribute if possible.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyNamedAttribute
- //----------------------------------------------------------------------------
- nb::class_<PyNamedAttribute>(m, "NamedAttribute")
- .def(
- "__repr__",
- [](PyNamedAttribute &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("NamedAttribute(");
- printAccum.parts.append(
- nb::str(mlirIdentifierStr(self.namedAttr.name).data,
- mlirIdentifierStr(self.namedAttr.name).length));
- printAccum.parts.append("=");
- mlirAttributePrint(self.namedAttr.attribute,
- printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the named attribute.")
- .def_prop_ro(
- "name",
- [](PyNamedAttribute &self) {
- return mlirIdentifierStr(self.namedAttr.name);
- },
- "The name of the `NamedAttribute` binding.")
- .def_prop_ro(
- "attr",
- [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
- nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
- "The underlying generic attribute of the `NamedAttribute` binding.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyType.
- //----------------------------------------------------------------------------
- nb::class_<PyType>(m, "Type")
- // Delegate to the PyType copy constructor, which will also lifetime
- // extend the backing context which owns the MlirType.
- .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
- "Casts the passed type to the generic `Type`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
- "Gets a capsule wrapping the `MlirType`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
- "Creates a Type from a capsule wrapping `MlirType`.")
- .def_static(
- "parse",
- [](std::string typeSpec,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type =
- mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
- if (mlirTypeIsNull(type))
- throw MLIRError("Unable to parse type", errors.take());
- return PyType(context.get()->getRef(), type).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- R"(
- Parses the assembly form of a type.
-
- Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
-
- See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
- .def_prop_ro(
- "context",
- [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Type`.")
- .def(
- "__eq__", [](PyType &self, PyType &other) { return self == other; },
- "Compares two types for equality.")
- .def(
- "__eq__", [](PyType &self, nb::object &other) { return false; },
- nb::arg("other").none(),
- "Compares type with non-type object (always returns False).")
- .def(
- "__hash__",
- [](PyType &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the `Type`.")
- .def(
- "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
- .def(
- "__str__",
- [](PyType &self) {
- PyPrintAccumulator printAccum;
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the `Type`.")
- .def(
- "__repr__",
- [](PyType &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, types are an exception as they typically have compact
- // assembly forms and printing them is useful.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Type(");
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the `Type`.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyType &self) -> nb::typed<nb::object, PyType> {
- return self.maybeDownCast();
- },
- "Downcasts the Type to a more specific `Type` if possible.")
- .def_prop_ro(
- "typeid",
- [](PyType &self) {
- MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
- if (!mlirTypeIDIsNull(mlirTypeID))
- return PyTypeID(mlirTypeID);
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
- throw nb::value_error(
- (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
- },
- "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
- "`Type` has no "
- "`TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyTypeID.
- //----------------------------------------------------------------------------
- nb::class_<PyTypeID>(m, "TypeID")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
- "Gets a capsule wrapping the `MlirTypeID`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
- "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
- // Note, this tests whether the underlying TypeIDs are the same,
- // not whether the wrapper MlirTypeIDs are the same, nor whether
- // the Python objects are the same (i.e., PyTypeID is a value type).
- .def(
- "__eq__",
- [](PyTypeID &self, PyTypeID &other) { return self == other; },
- "Compares two `TypeID`s for equality.")
- .def(
- "__eq__",
- [](PyTypeID &self, const nb::object &other) { return false; },
- "Compares TypeID with non-TypeID object (always returns False).")
- // Note, this gives the hash value of the underlying TypeID, not the
- // hash value of the Python object, nor the hash value of the
- // MlirTypeID wrapper.
- .def(
- "__hash__",
- [](PyTypeID &self) {
- return static_cast<size_t>(mlirTypeIDHashValue(self));
- },
- "Returns the hash value of the `TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of Value.
- //----------------------------------------------------------------------------
- m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
-
- nb::class_<PyValue>(m, "Value", nb::is_generic(),
- nb::sig("class Value(Generic[_T])"))
- .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
- "Creates a Value reference from another `Value`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
- "Gets a capsule wrapping the `MlirValue`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
- "Creates a `Value` from a capsule wrapping `MlirValue`.")
- .def_prop_ro(
- "context",
- [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getParentOperation()->getContext().getObject();
- },
- "Context in which the value lives.")
- .def(
- "dump", [](PyValue &self) { mlirValueDump(self.get()); },
- kDumpDocstring)
- .def_prop_ro(
- "owner",
- [](PyValue &self) -> nb::object {
- MlirValue v = self.get();
- if (mlirValueIsAOpResult(v)) {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match "
- "that in "
- "the IR");
- return self.getParentOperation().getObject();
- }
-
- if (mlirValueIsABlockArgument(v)) {
- MlirBlock block = mlirBlockArgumentGetOwner(self.get());
- return nb::cast(PyBlock(self.getParentOperation(), block));
- }
-
- assert(false && "Value must be a block argument or an op result");
- return nb::none();
- },
- "Returns the owner of the value (`Operation` for results, `Block` "
- "for "
- "arguments).")
- .def_prop_ro(
- "uses",
- [](PyValue &self) {
- return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
- },
- "Returns an iterator over uses of this value.")
- .def(
- "__eq__",
- [](PyValue &self, PyValue &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two values for pointer equality.")
- .def(
- "__eq__", [](PyValue &self, nb::object other) { return false; },
- "Compares value with non-value object (always returns False).")
- .def(
- "__hash__",
- [](PyValue &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the value.")
- .def(
- "__str__",
- [](PyValue &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Value(");
- mlirValuePrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- R"(
- Returns the string form of the value.
-
- If the value is a block argument, this is the assembly form of its type and the
- position in the argument list. If the value is an operation result, this is
- equivalent to printing the operation that produced it.
- )")
- .def(
- "get_name",
- [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
- PyPrintAccumulator printAccum;
- MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- if (useNameLocAsPrefix)
- mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
- MlirAsmState valueState =
- mlirAsmStateCreateForValue(self.get(), flags);
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- mlirOpPrintingFlagsDestroy(flags);
- mlirAsmStateDestroy(valueState);
- return printAccum.join();
- },
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- R"(
- Returns the string form of value as an operand.
-
- Args:
- use_local_scope: Whether to use local scope for naming.
- use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
-
- Returns:
- The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
- .def(
- "get_name",
- [](PyValue &self, PyAsmState &state) {
- PyPrintAccumulator printAccum;
- MlirAsmState valueState = state.get();
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- nb::arg("state"),
- "Returns the string form of value as an operand (i.e., the ValueID).")
- .def_prop_ro(
- "type",
- [](PyValue &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()))
- .maybeDownCast();
- },
- "Returns the type of the value.")
- .def(
- "set_type",
- [](PyValue &self, const PyType &type) {
- mlirValueSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of the value.",
- nb::sig("def set_type(self, type: _T)"))
- .def(
- "replace_all_uses_with",
- [](PyValue &self, PyValue &with) {
- mlirValueReplaceAllUsesOfWith(self.get(), with.get());
- },
- "Replace all uses of value with the new value, updating anything in "
- "the IR that uses `self` to use the other value instead.")
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, const nb::list &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (nb::handle exception : exceptions) {
- exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
- }
-
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with,
- std::vector<PyOperation> &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (PyOperation &exception : exceptions)
- exceptionOps.push_back(exception);
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- },
- "Downcasts the `Value` to a more specific kind if possible.")
- .def_prop_ro(
- "location",
- [](MlirValue self) {
- return PyLocation(
- PyMlirContext::forContext(mlirValueGetContext(self)),
- mlirValueGetLocation(self));
- },
- "Returns the source location of the value.");
-
- PyBlockArgument::bind(m);
- PyOpResult::bind(m);
- PyOpOperand::bind(m);
-
- nb::class_<PyAsmState>(m, "AsmState")
- .def(nb::init<PyValue &, bool>(), nb::arg("value"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an `AsmState` for consistent SSA value naming.
-
- Args:
- value: The value to create state for.
- use_local_scope: Whether to use local scope for naming.)")
- .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an AsmState for consistent SSA value naming.
-
- Args:
- op: The operation to create state for.
- use_local_scope: Whether to use local scope for naming.)");
-
- //----------------------------------------------------------------------------
- // Mapping of SymbolTable.
- //----------------------------------------------------------------------------
- nb::class_<PySymbolTable>(m, "SymbolTable")
- .def(nb::init<PyOperationBase &>(),
- R"(
- Creates a symbol table for an operation.
-
- Args:
- operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
-
- Raises:
- TypeError: If the operation is not a symbol table.)")
- .def(
- "__getitem__",
- [](PySymbolTable &self,
- const std::string &name) -> nb::typed<nb::object, PyOpView> {
- return self.dunderGetItem(name);
- },
- R"(
- Looks up a symbol by name in the symbol table.
-
- Args:
- name: The name of the symbol to look up.
-
- Returns:
- The operation defining the symbol.
-
- Raises:
- KeyError: If the symbol is not found.)")
- .def("insert", &PySymbolTable::insert, nb::arg("operation"),
- R"(
- Inserts a symbol operation into the symbol table.
-
- Args:
- operation: An operation with a symbol name to insert.
-
- Returns:
- The symbol name attribute of the inserted operation.
-
- Raises:
- ValueError: If the operation does not have a symbol name.)")
- .def("erase", &PySymbolTable::erase, nb::arg("operation"),
- R"(
- Erases a symbol operation from the symbol table.
-
- Args:
- operation: The symbol operation to erase.
-
- Note:
- The operation is also erased from the IR and invalidated.)")
- .def("__delitem__", &PySymbolTable::dunderDel,
- "Deletes a symbol by name from the symbol table.")
- .def(
- "__contains__",
- [](PySymbolTable &table, const std::string &name) {
- return !mlirOperationIsNull(mlirSymbolTableLookup(
- table, mlirStringRefCreate(name.data(), name.length())));
- },
- "Checks if a symbol with the given name exists in the table.")
- // Static helpers.
- .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
- nb::arg("symbol"), nb::arg("name"),
- "Sets the symbol name for a symbol operation.")
- .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
- nb::arg("symbol"),
- "Gets the symbol name from a symbol operation.")
- .def_static("get_visibility", &PySymbolTable::getVisibility,
- nb::arg("symbol"),
- "Gets the visibility attribute of a symbol operation.")
- .def_static("set_visibility", &PySymbolTable::setVisibility,
- nb::arg("symbol"), nb::arg("visibility"),
- "Sets the visibility attribute of a symbol operation.")
- .def_static("replace_all_symbol_uses",
- &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
- nb::arg("new_symbol"), nb::arg("from_op"),
- "Replaces all uses of a symbol with a new symbol name within "
- "the given operation.")
- .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
- nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
- nb::arg("callback"),
- "Walks symbol tables starting from an operation with a "
- "callback function.");
-
- // Container bindings.
- PyBlockArgumentList::bind(m);
- PyBlockIterator::bind(m);
- PyBlockList::bind(m);
- PyBlockSuccessors::bind(m);
- PyBlockPredecessors::bind(m);
- PyOperationIterator::bind(m);
- PyOperationList::bind(m);
- PyOpAttributeMap::bind(m);
- PyOpOperandIterator::bind(m);
- PyOpOperandList::bind(m);
- PyOpResultList::bind(m);
- PyOpSuccessors::bind(m);
- PyRegionIterator::bind(m);
- PyRegionList::bind(m);
-
- // Debug bindings.
- PyGlobalDebugFlag::bind(m);
-
- // Attribute builder getter.
- PyAttrBuilderMap::bind(m);
-
+void registerMLIRErrorInIRCore() {
nb::register_exception_translator([](const std::exception_ptr &p,
void *payload) {
- // We can't define exceptions with custom fields through pybind, so instead
- // the exception class is defined in python and imported here.
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
try {
if (p)
std::rethrow_exception(p);
@@ -4963,3 +1687,4 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
});
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 31d4798ffb906..f1e494c375523 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,7 +12,7 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 34c5b8dd86a66..294ab91a059e2 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -7,14 +7,13 @@
//===----------------------------------------------------------------------===//
// clang-format off
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRTypes.h"
// clang-format on
#include <optional>
-#include "IRModule.h"
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
@@ -1144,7 +1143,8 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
} // namespace
-void mlir::python::populateIRTypes(nb::module_ &m) {
+namespace mlir::python {
+void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
@@ -1175,4 +1175,18 @@ void mlir::python::populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+});
+}
}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index ba767ad6692cf..686c55ee1e6a8 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,18 +6,2275 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
using namespace mlir::python;
+// Docstrings shared by several of the bindings registered in this file.
+
+/// Used by the `Module.parse` / `Module.parseFile` bindings.
+static constexpr char kModuleParseDocstring[] =
+    R"(Parses a module's assembly format from a string.
+
+Returns a new MlirModule or raises an MLIRError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+
+/// Used by the `dump` bindings.
+static constexpr char kDumpDocstring[] =
+    "Dumps a debug representation of the object to stderr.";
+
+/// Describes replacing every use of a value except inside given operations.
+static constexpr char kValueReplaceAllUsesExceptDocstring[] =
+    R"(Replace all uses of this value with the `with` value, except for those
+in `exceptions`. `exceptions` can be either a single operation or a list of
+operations.
+)";
+
+namespace {
+// Compatibility shims for older CPython releases, adapted from
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
+// They backfill the reference/frame helpers used by tracebackToLocation when
+// building against a Python version that predates them.
+
+#ifndef _Py_CAST
+#define _Py_CAST(type, expr) ((type)(expr))
+#endif
+
+// Static inline functions should use _Py_NULL rather than using directly NULL
+// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
+// _Py_NULL is defined as nullptr.
+#ifndef _Py_NULL
+#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) ||               \
+    (defined(__cplusplus) && __cplusplus >= 201103)
+#define _Py_NULL nullptr
+#else
+#define _Py_NULL NULL
+#endif
+#endif
+
+// Python 3.10.0a3
+#if PY_VERSION_HEX < 0x030A00A3
+
+// bpo-42262 added Py_XNewRef()
+#if !defined(Py_XNewRef)
+[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
+  Py_XINCREF(obj);
+  return obj;
+}
+#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
+#endif
+
+// bpo-42262 added Py_NewRef()
+#if !defined(Py_NewRef)
+[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
+  Py_INCREF(obj);
+  return obj;
+}
+#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
+#endif
+
+#endif // Python 3.10.0a3
+
+// Python 3.9.0b1
+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
+
+// bpo-40429 added PyThreadState_GetFrame(). Returns a new reference (or null).
+PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
+  assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
+}
+
+// bpo-40421 added PyFrame_GetBack(). Returns a new reference (or null).
+PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
+}
+
+// bpo-40421 added PyFrame_GetCode(). Returns a new reference.
+PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
+  return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
+}
+
+#endif // Python 3.9.0b1
+
+/// Converts the current Python call stack into an MLIR location.
+///
+/// Frames are visited from the innermost outwards. Only frames whose filename
+/// is accepted by `TracebackLoc::isUserTracebackFilename` are kept, and at most
+/// `locTracebackFramesLimit()` of them (never more than `kMaxFrames`). Each
+/// kept frame becomes a NameLoc carrying the function name around a file/line
+/// location (a file/line/column range on Python >= 3.11).
+///
+/// Returns an unknown location when no frame is kept, the single NameLoc when
+/// exactly one is kept, and otherwise a CallSiteLoc whose callee is the
+/// innermost kept frame and whose caller nests the remaining frames.
+/// Acquires the GIL for the duration of the walk.
+MlirLocation tracebackToLocation(MlirContext ctx) {
+  size_t framesLimit =
+      PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
+  // Use a thread_local here to avoid requiring a large amount of space.
+  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
+      frames;
+  // Never index past the end of `frames`, whatever limit was configured.
+  if (framesLimit > frames.size())
+    framesLimit = frames.size();
+  size_t count = 0;
+
+  // Takes ownership of a new (strong) reference returned by the C API so it is
+  // released on every exit path: normal termination, `continue`, and
+  // exceptions thrown from the loop body.
+  auto own = [](auto *obj) {
+    return nb::steal(reinterpret_cast<PyObject *>(obj));
+  };
+
+  nb::gil_scoped_acquire acquire;
+  PyThreadState *tstate = PyThreadState_GET();
+  // Replacing `frameRef` with the next outer frame drops our reference to the
+  // current one, so it can be gc'd along with any objects in its closure.
+  for (nb::object frameRef = own(PyThreadState_GetFrame(tstate));
+       frameRef.is_valid() && count < framesLimit;
+       frameRef = own(PyFrame_GetBack(
+           reinterpret_cast<PyFrameObject *>(frameRef.ptr())))) {
+    auto *pyFrame = reinterpret_cast<PyFrameObject *>(frameRef.ptr());
+    // PyFrame_GetCode also returns a new reference that must be released.
+    nb::object codeRef = own(PyFrame_GetCode(pyFrame));
+    auto *code = reinterpret_cast<PyCodeObject *>(codeRef.ptr());
+    auto fileNameStr =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
+    llvm::StringRef fileName(fileNameStr);
+    if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
+      continue;
+
+    // co_qualname and PyCode_Addr2Location added in py3.11
+#if PY_VERSION_HEX < 0x030B00F0
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
+    llvm::StringRef funcName(name);
+    int startLine = PyFrame_GetLineNumber(pyFrame);
+    MlirLocation loc =
+        mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
+#else
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
+    llvm::StringRef funcName(name);
+    int startLine, startCol, endLine, endCol;
+    int lasti = PyFrame_GetLasti(pyFrame);
+    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
+                              &endCol)) {
+      // frameRef/codeRef are released by their destructors during unwinding.
+      throw nb::python_error();
+    }
+    MlirLocation loc = mlirLocationFileLineColRangeGet(
+        ctx, wrap(fileName), startLine, startCol, endLine, endCol);
+#endif
+
+    frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
+    ++count;
+  }
+
+  if (count == 0)
+    return mlirLocationUnknownGet(ctx);
+
+  MlirLocation callee = frames[0];
+  assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
+  if (count == 1)
+    return callee;
+
+  // Nest the outer frames starting from the outermost one: frames[count - 1]
+  // is the outermost caller and frames[1] ends up as the direct caller of
+  // `callee`. `i` starts at count - 2 >= 0, so it cannot underflow.
+  MlirLocation caller = frames[count - 1];
+  assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
+  for (size_t i = count - 2; i >= 1; --i)
+    caller = mlirLocationCallSiteGet(frames[i], caller);
+
+  return mlirLocationCallSiteGet(callee, caller);
+}
+
+/// Returns `location` when one is provided. Otherwise, if traceback locations
+/// are enabled, builds one from the current Python call stack in the default
+/// context. If they are disabled, falls back to
+/// `DefaultingPyLocation::resolve()`.
+PyLocation
+maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
+  if (location.has_value())
+    return location.value();
+  if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
+    return DefaultingPyLocation::resolve();
+
+  PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
+  MlirLocation mlirLoc = tracebackToLocation(ctx.get());
+  PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
+  return {ref, mlirLoc};
+}
+} // namespace
+
+//------------------------------------------------------------------------------
+// Populates the core exports of the 'ir' submodule.
+//------------------------------------------------------------------------------
+
+static void populateIRCore(nb::module_ &m) {
+ // disable leak warnings which tend to be false positives.
+ nb::set_leak_warnings(false);
+ //----------------------------------------------------------------------------
+ // Enums.
+ //----------------------------------------------------------------------------
+ nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
+ .value("ERROR", MlirDiagnosticError)
+ .value("WARNING", MlirDiagnosticWarning)
+ .value("NOTE", MlirDiagnosticNote)
+ .value("REMARK", MlirDiagnosticRemark);
+
+ nb::enum_<MlirWalkOrder>(m, "WalkOrder")
+ .value("PRE_ORDER", MlirWalkPreOrder)
+ .value("POST_ORDER", MlirWalkPostOrder);
+
+ nb::enum_<MlirWalkResult>(m, "WalkResult")
+ .value("ADVANCE", MlirWalkResultAdvance)
+ .value("INTERRUPT", MlirWalkResultInterrupt)
+ .value("SKIP", MlirWalkResultSkip);
+
+ //----------------------------------------------------------------------------
+ // Mapping of Diagnostics.
+ //----------------------------------------------------------------------------
+ nb::class_<PyDiagnostic>(m, "Diagnostic")
+ .def_prop_ro("severity", &PyDiagnostic::getSeverity,
+ "Returns the severity of the diagnostic.")
+ .def_prop_ro("location", &PyDiagnostic::getLocation,
+ "Returns the location associated with the diagnostic.")
+ .def_prop_ro("message", &PyDiagnostic::getMessage,
+ "Returns the message text of the diagnostic.")
+ .def_prop_ro("notes", &PyDiagnostic::getNotes,
+ "Returns a tuple of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic &self) -> nb::str {
+ if (!self.isValid())
+ return nb::str("<Invalid Diagnostic>");
+ return self.getMessage();
+ },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
+ .def(
+ "__init__",
+ [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
+ new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
+ },
+ "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
+ .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
+ "The severity level of the diagnostic.")
+ .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
+ "The location associated with the diagnostic.")
+ .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
+ "The message text of the diagnostic.")
+ .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
+ "List of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
+ .def("detach", &PyDiagnosticHandler::detach,
+ "Detaches the diagnostic handler from the context.")
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
+ "Returns True if the handler is attached to a context.")
+ .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
+ "Returns True if an error was encountered during diagnostic "
+ "handling.")
+ .def("__enter__", &PyDiagnosticHandler::contextEnter,
+ "Enters the diagnostic handler as a context manager.")
+ .def("__exit__", &PyDiagnosticHandler::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none(),
+ "Exits the diagnostic handler context manager.");
+
+ // Expose DefaultThreadPool to python
+ nb::class_<PyThreadPool>(m, "ThreadPool")
+ .def(
+ "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
+ "Creates a new thread pool with default concurrency.")
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
+ "Returns the maximum number of threads in the pool.")
+ .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
+ "Returns the raw pointer to the LLVM thread pool as a string.");
+
+ nb::class_<PyMlirContext>(m, "Context")
+ .def(
+ "__init__",
+ [](PyMlirContext &self) {
+ MlirContext context = mlirContextCreateWithThreading(false);
+ new (&self) PyMlirContext(context);
+ },
+ R"(
+ Creates a new MLIR context.
+
+ The context is the top-level container for all MLIR objects. It owns the storage
+ for types, attributes, locations, and other core IR objects. A context can be
+ configured to allow or disallow unregistered dialects and can have dialects
+ loaded on-demand.)")
+ .def_static("_get_live_count", &PyMlirContext::getLiveCount,
+ "Gets the number of live Context objects.")
+ .def(
+ "_get_context_again",
+ [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
+ },
+ "Gets another reference to the same context.")
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
+ "Gets the number of live modules owned by this context.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
+ "Gets a capsule wrapping the MlirContext.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyMlirContext::createFromCapsule,
+ "Creates a Context from a capsule wrapping MlirContext.")
+ .def("__enter__", &PyMlirContext::contextEnter,
+ "Enters the context as a context manager.")
+ .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the context manager.")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/)
+ -> std::optional<nb::typed<nb::object, PyMlirContext>> {
+ auto *context = PyThreadContextEntry::getDefaultContext();
+ if (!context)
+ return {};
+ return nb::cast(context);
+ },
+ nb::sig("def current(/) -> Context | None"),
+ "Gets the Context bound to the current thread or returns None if no "
+ "context is set.")
+ .def_prop_ro(
+ "dialects",
+ [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Gets a container for accessing dialects by name.")
+ .def_prop_ro(
+ "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Alias for `dialects`.")
+ .def(
+ "get_dialect_descriptor",
+ [=](PyMlirContext &self, std::string &name) {
+ MlirDialect dialect = mlirContextGetOrLoadDialect(
+ self.get(), {name.data(), name.size()});
+ if (mlirDialectIsNull(dialect)) {
+ throw nb::value_error(
+ (Twine("Dialect '") + name + "' not found").str().c_str());
+ }
+ return PyDialectDescriptor(self.getRef(), dialect);
+ },
+ nb::arg("dialect_name"),
+ "Gets or loads a dialect by name, returning its descriptor object.")
+ .def_prop_rw(
+ "allow_unregistered_dialects",
+ [](PyMlirContext &self) -> bool {
+ return mlirContextGetAllowUnregisteredDialects(self.get());
+ },
+ [](PyMlirContext &self, bool value) {
+ mlirContextSetAllowUnregisteredDialects(self.get(), value);
+ },
+ "Controls whether unregistered dialects are allowed in this context.")
+ .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
+ nb::arg("callback"),
+ "Attaches a diagnostic handler that will receive callbacks.")
+ .def(
+ "enable_multithreading",
+ [](PyMlirContext &self, bool enable) {
+ mlirContextEnableMultithreading(self.get(), enable);
+ },
+ nb::arg("enable"),
+ R"(
+ Enables or disables multi-threading support in the context.
+
+ Args:
+ enable: Whether to enable (True) or disable (False) multi-threading.
+ )")
+ .def(
+ "set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+ // we should disable multi-threading first before setting
+ // new thread pool otherwise the assert in
+ // MLIRContext::setThreadPool will be raised.
+ mlirContextEnableMultithreading(self.get(), false);
+ mlirContextSetThreadPool(self.get(), pool.get());
+ },
+ R"(
+ Sets a custom thread pool for the context to use.
+
+ Args:
+ pool: A ThreadPool object to use for parallel operations.
+
+ Note:
+ Multi-threading is automatically disabled before setting the thread pool.)")
+ .def(
+ "get_num_threads",
+ [](PyMlirContext &self) {
+ return mlirContextGetNumThreads(self.get());
+ },
+ "Gets the number of threads in the context's thread pool.")
+ .def(
+ "_mlir_thread_pool_ptr",
+ [](PyMlirContext &self) {
+ MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
+ std::stringstream ss;
+ ss << pool.ptr;
+ return ss.str();
+ },
+ "Gets the raw pointer to the LLVM thread pool as a string.")
+ .def(
+ "is_registered_operation",
+ [](PyMlirContext &self, std::string &name) {
+ return mlirContextIsRegisteredOperation(
+ self.get(), MlirStringRef{name.data(), name.size()});
+ },
+ nb::arg("operation_name"),
+ R"(
+ Checks whether an operation with the given name is registered.
+
+ Args:
+ operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
+
+ Returns:
+ True if the operation is registered, False otherwise.)")
+ .def(
+ "append_dialect_registry",
+ [](PyMlirContext &self, PyDialectRegistry ®istry) {
+ mlirContextAppendDialectRegistry(self.get(), registry);
+ },
+ nb::arg("registry"),
+ R"(
+ Appends the contents of a dialect registry to the context.
+
+ Args:
+ registry: A DialectRegistry containing dialects to append.)")
+ .def_prop_rw("emit_error_diagnostics",
+ &PyMlirContext::getEmitErrorDiagnostics,
+ &PyMlirContext::setEmitErrorDiagnostics,
+ R"(
+ Controls whether error diagnostics are emitted to diagnostic handlers.
+
+ By default, error diagnostics are captured and reported through MLIRError exceptions.)")
+ .def(
+ "load_all_available_dialects",
+ [](PyMlirContext &self) {
+ mlirContextLoadAllAvailableDialects(self.get());
+ },
+ R"(
+ Loads all dialects available in the registry into the context.
+
+ This eagerly loads all dialects that have been registered, making them
+ immediately available for use.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectDescriptor
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+ .def_prop_ro(
+ "namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ return nb::str(ns.data, ns.length);
+ },
+ "Returns the namespace of the dialect.")
+ .def(
+ "__repr__",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ std::string repr("<DialectDescriptor ");
+ repr.append(ns.data, ns.length);
+ repr.append(">");
+ return repr;
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect descriptor.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialects
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialects>(m, "Dialects")
+ .def(
+ "__getitem__",
+ [=](PyDialects &self, std::string keyName) {
+ MlirDialect dialect =
+ self.getDialectForKey(keyName, /*attrError=*/false);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(keyName, std::move(descriptor));
+ },
+ "Gets a dialect by name using subscript notation.")
+ .def(
+ "__getattr__",
+ [=](PyDialects &self, std::string attrName) {
+ MlirDialect dialect =
+ self.getDialectForKey(attrName, /*attrError=*/true);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(attrName, std::move(descriptor));
+ },
+ "Gets a dialect by name using attribute notation.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialect
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialect>(m, "Dialect")
+ .def(nb::init<nb::object>(), nb::arg("descriptor"),
+ "Creates a Dialect from a DialectDescriptor.")
+ .def_prop_ro(
+ "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
+ "Returns the DialectDescriptor for this dialect.")
+ .def(
+ "__repr__",
+ [](const nb::object &self) {
+ auto clazz = self.attr("__class__");
+ return nb::str("<Dialect ") +
+ self.attr("descriptor").attr("namespace") +
+ nb::str(" (class ") + clazz.attr("__module__") +
+ nb::str(".") + clazz.attr("__name__") + nb::str(")>");
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectRegistry
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectRegistry>(m, "DialectRegistry")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
+ "Gets a capsule wrapping the MlirDialectRegistry.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyDialectRegistry::createFromCapsule,
+ "Creates a DialectRegistry from a capsule wrapping "
+ "`MlirDialectRegistry`.")
+ .def(nb::init<>(), "Creates a new empty dialect registry.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Location
+ //----------------------------------------------------------------------------
+ nb::class_<PyLocation>(m, "Location")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
+ "Gets a capsule wrapping the MlirLocation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
+ "Creates a Location from a capsule wrapping MlirLocation.")
+ .def("__enter__", &PyLocation::contextEnter,
+ "Enters the location as a context manager.")
+ .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the location context manager.")
+ .def(
+ "__eq__",
+ [](PyLocation &self, PyLocation &other) -> bool {
+ return mlirLocationEqual(self, other);
+ },
+ "Compares two locations for equality.")
+ .def(
+ "__eq__", [](PyLocation &self, nb::object other) { return false; },
+ "Compares location with non-location object (always returns False).")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) -> std::optional<PyLocation *> {
+ auto *loc = PyThreadContextEntry::getDefaultLocation();
+ if (!loc)
+ return std::nullopt;
+ return loc;
+ },
+ // clang-format off
+ nb::sig("def current(/) -> Location | None"),
+ // clang-format on
+ "Gets the Location bound to the current thread or raises ValueError.")
+ .def_static(
+ "unknown",
+ [](DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationUnknownGet(context->get()));
+ },
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing an unknown location.")
+ .def_static(
+ "callsite",
+ [](PyLocation callee, const std::vector<PyLocation> &frames,
+ DefaultingPyMlirContext context) {
+ if (frames.empty())
+ throw nb::value_error("No caller frames provided.");
+ MlirLocation caller = frames.back().get();
+ for (const PyLocation &frame :
+ llvm::reverse(llvm::ArrayRef(frames).drop_back()))
+ caller = mlirLocationCallSiteGet(frame.get(), caller);
+ return PyLocation(context->getRef(),
+ mlirLocationCallSiteGet(callee.get(), caller));
+ },
+ nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
+ "Gets a Location representing a caller and callsite.")
+ .def("is_a_callsite", mlirLocationIsACallSite,
+ "Returns True if this location is a CallSiteLoc.")
+ .def_prop_ro(
+ "callee",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCallee(self));
+ },
+ "Gets the callee location from a CallSiteLoc.")
+ .def_prop_ro(
+ "caller",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCaller(self));
+ },
+ "Gets the caller location from a CallSiteLoc.")
+ .def_static(
+ "file",
+ [](std::string filename, int line, int col,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationFileLineColGet(
+ context->get(), toMlirStringRef(filename), line, col));
+ },
+ nb::arg("filename"), nb::arg("line"), nb::arg("col"),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column.")
+ .def_static(
+ "file",
+ [](std::string filename, int startLine, int startCol, int endLine,
+ int endCol, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFileLineColRangeGet(
+ context->get(), toMlirStringRef(filename),
+ startLine, startCol, endLine, endCol));
+ },
+ nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
+ nb::arg("end_line"), nb::arg("end_col"),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column range.")
+ .def("is_a_file", mlirLocationIsAFileLineColRange,
+ "Returns True if this location is a FileLineColLoc.")
+ .def_prop_ro(
+ "filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ },
+ "Gets the filename from a FileLineColLoc.")
+ .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
+ "Gets the start line number from a `FileLineColLoc`.")
+ .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
+ "Gets the start column number from a `FileLineColLoc`.")
+ .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
+ "Gets the end line number from a `FileLineColLoc`.")
+ .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
+ "Gets the end column number from a `FileLineColLoc`.")
+ .def_static(
+ "fused",
+ [](const std::vector<PyLocation> &pyLocations,
+ std::optional<PyAttribute> metadata,
+ DefaultingPyMlirContext context) {
+ llvm::SmallVector<MlirLocation, 4> locations;
+ locations.reserve(pyLocations.size());
+ for (auto &pyLocation : pyLocations)
+ locations.push_back(pyLocation.get());
+ MlirLocation location = mlirLocationFusedGet(
+ context->get(), locations.size(), locations.data(),
+ metadata ? metadata->get() : MlirAttribute{0});
+ return PyLocation(context->getRef(), location);
+ },
+ nb::arg("locations"), nb::arg("metadata") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a fused location with optional "
+ "metadata.")
+ .def("is_a_fused", mlirLocationIsAFused,
+ "Returns True if this location is a `FusedLoc`.")
+ .def_prop_ro(
+ "locations",
+ [](PyLocation &self) {
+ unsigned numLocations = mlirLocationFusedGetNumLocations(self);
+ std::vector<MlirLocation> locations(numLocations);
+ if (numLocations)
+ mlirLocationFusedGetLocations(self, locations.data());
+ std::vector<PyLocation> pyLocations{};
+ pyLocations.reserve(numLocations);
+ for (unsigned i = 0; i < numLocations; ++i)
+ pyLocations.emplace_back(self.getContext(), locations[i]);
+ return pyLocations;
+ },
+ "Gets the list of locations from a `FusedLoc`.")
+ .def_static(
+ "name",
+ [](std::string name, std::optional<PyLocation> childLoc,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationNameGet(
+ context->get(), toMlirStringRef(name),
+ childLoc ? childLoc->get()
+ : mlirLocationUnknownGet(context->get())));
+ },
+ nb::arg("name"), nb::arg("childLoc") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a named location with optional child "
+ "location.")
+ .def("is_a_name", mlirLocationIsAName,
+ "Returns True if this location is a `NameLoc`.")
+ .def_prop_ro(
+ "name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ },
+ "Gets the name string from a `NameLoc`.")
+ .def_prop_ro(
+ "child_loc",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationNameGetChildLoc(self));
+ },
+ "Gets the child location from a `NameLoc`.")
+ .def_static(
+ "from_attr",
+ [](PyAttribute &attribute, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFromAttribute(attribute));
+ },
+ nb::arg("attribute"), nb::arg("context") = nb::none(),
+ "Gets a Location from a `LocationAttr`.")
+ .def_prop_ro(
+ "context",
+ [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Location`.")
+ .def_prop_ro(
+ "attr",
+ [](PyLocation &self) {
+ return PyAttribute(self.getContext(),
+ mlirLocationGetAttribute(self));
+ },
+ "Get the underlying `LocationAttr`.")
+ .def(
+ "emit_error",
+ [](PyLocation &self, std::string message) {
+ mlirEmitError(self, message.c_str());
+ },
+ nb::arg("message"),
+ R"(
+ Emits an error diagnostic at this location.
+
+ Args:
+ message: The error message to emit.)")
+ .def(
+ "__repr__",
+ [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly representation of the location.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Module
+ //----------------------------------------------------------------------------
+ nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
+ "Gets a capsule wrapping the MlirModule.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ R"(
+ Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
+
+ This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
+ prevent double-frees (of the underlying `mlir::Module`).)")
+ .def("_clear_mlir_module", &PyModule::clearMlirModule,
+ R"(
+ Clears the internal MLIR module reference.
+
+ This is used internally to prevent double-free when ownership is transferred
+ via the C API capsule mechanism. Not intended for normal use.)")
+ .def_static(
+ "parse",
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "parse",
+ [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "parseFile",
+ [](const std::string &path, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParseFromFile(
+ context->get(), toMlirStringRef(path));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("path"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "create",
+ [](const std::optional<PyLocation> &loc)
+ -> nb::typed<nb::object, PyModule> {
+ PyLocation pyLoc = maybeGetTracebackLocation(loc);
+ MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("loc") = nb::none(), "Creates an empty module.")
+ .def_prop_ro(
+ "context",
+ [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that created the `Module`.")
+ .def_prop_ro(
+ "operation",
+ [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
+ return PyOperation::forOperation(self.getContext(),
+ mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject())
+ .releaseObject();
+ },
+ "Accesses the module as an operation.")
+ .def_prop_ro(
+ "body",
+ [](PyModule &self) {
+ PyOperationRef moduleOp = PyOperation::forOperation(
+ self.getContext(), mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject());
+ PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
+ return returnBlock;
+ },
+ "Return the block for this module.")
+ .def(
+ "dump",
+ [](PyModule &self) {
+ mlirOperationDump(mlirModuleGetOperation(self.get()));
+ },
+ kDumpDocstring)
+ .def(
+ "__str__",
+ [](const nb::object &self) {
+ // Defer to the operation's __str__.
+ return self.attr("operation").attr("__str__")();
+ },
+ nb::sig("def __str__(self) -> str"),
+ R"(
+ Gets the assembly form of the operation with default options.
+
+ If more advanced control over the assembly formatting or I/O options is needed,
+ use the dedicated print or get_asm method, which supports keyword arguments to
+ customize behavior.
+ )")
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a, "Compares two modules for equality.")
+ .def(
+ "__hash__",
+ [](PyModule &self) { return mlirModuleHashValue(self.get()); },
+ "Returns the hash value of the module.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Operation.
+ //----------------------------------------------------------------------------
+ // `_OperationBase` carries the accessors shared by `Operation` and `OpView`;
+ // the lambdas below resolve the concrete PyOperation via `getOperation()`.
+ nb::class_<PyOperationBase>(m, "_OperationBase")
+ .def_prop_ro(
+ MLIR_PYTHON_CAPI_PTR_ATTR,
+ [](PyOperationBase &self) {
+ return self.getOperation().getCapsule();
+ },
+ "Gets a capsule wrapping the `MlirOperation`.")
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, PyOperationBase &other) {
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
+ },
+ "Compares two operations for equality.")
+ // Fallback overload, registered after the typed one above, so comparing
+ // with any non-operation object yields False.
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, nb::object other) { return false; },
+ "Compares operation with non-operation object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyOperationBase &self) {
+ return mlirOperationHashValue(self.getOperation().get());
+ },
+ "Returns the hash value of the operation.")
+ .def_prop_ro(
+ "attributes",
+ [](PyOperationBase &self) {
+ return PyOpAttributeMap(self.getOperation().getRef());
+ },
+ "Returns a dictionary-like map of operation attributes.")
+ .def_prop_ro(
+ "context",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyOperation &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ return concreteOperation.getContext().getObject();
+ },
+ "Context that owns the operation.")
+ .def_prop_ro(
+ "name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ return mlirIdentifierStr(mlirOperationGetName(operation));
+ },
+ "Returns the fully qualified name of the operation.")
+ .def_prop_ro(
+ "operands",
+ [](PyOperationBase &self) {
+ return PyOpOperandList(self.getOperation().getRef());
+ },
+ "Returns the list of operation operands.")
+ .def_prop_ro(
+ "regions",
+ [](PyOperationBase &self) {
+ return PyRegionList(self.getOperation().getRef());
+ },
+ "Returns the list of operation regions.")
+ .def_prop_ro(
+ "results",
+ [](PyOperationBase &self) {
+ return PyOpResultList(self.getOperation().getRef());
+ },
+ "Returns the list of Operation results.")
+ .def_prop_ro(
+ "result",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
+ auto &operation = self.getOperation();
+ // getUniqueResult is expected to reject ops without exactly one
+ // result (see the docstring below).
+ return PyOpResult(operation.getRef(), getUniqueResult(operation))
+ .maybeDownCast();
+ },
+ "Shortcut to get an op result if it has only one (throws an error "
+ "otherwise).")
+ .def_prop_rw(
+ "location",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ return PyLocation(operation.getContext(),
+ mlirOperationGetLocation(operation.get()));
+ },
+ [](PyOperationBase &self, const PyLocation &location) {
+ PyOperation &operation = self.getOperation();
+ mlirOperationSetLocation(operation.get(), location.get());
+ },
+ nb::for_getter("Returns the source location the operation was "
+ "defined or derived from."),
+ nb::for_setter("Sets the source location the operation was defined "
+ "or derived from."))
+ .def_prop_ro(
+ "parent",
+ [](PyOperationBase &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto parent = self.getOperation().getParentOperation();
+ if (parent)
+ return parent->getObject();
+ return {};
+ },
+ "Returns the parent operation, or `None` if at top level.")
+ .def(
+ "__str__",
+ [](PyOperationBase &self) {
+ return self.getAsm(/*binary=*/false,
+ /*largeElementsLimit=*/std::nullopt,
+ /*largeResourceLimit=*/std::nullopt,
+ /*enableDebugInfo=*/false,
+ /*prettyDebugInfo=*/false,
+ /*printGenericOpForm=*/false,
+ /*useLocalScope=*/false,
+ /*useNameLocAsPrefix=*/false,
+ /*assumeVerified=*/false,
+ /*skipRegions=*/false);
+ },
+ nb::sig("def __str__(self) -> str"),
+ "Returns the assembly form of the operation.")
+ .def("print",
+ nb::overload_cast<PyAsmState &, nb::object, bool>(
+ &PyOperationBase::print),
+ nb::arg("state"), nb::arg("file") = nb::none(),
+ nb::arg("binary") = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ state: `AsmState` capturing the operation numbering and flags.
+ file: Optional file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
+ .def("print",
+ nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
+ bool, bool, bool, bool, bool, bool, nb::object,
+ bool, bool>(&PyOperationBase::print),
+ // Careful: Lots of arguments must match up with print method.
+ nb::arg("large_elements_limit") = nb::none(),
+ nb::arg("large_resource_limit") = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
+ nb::arg("binary") = false, nb::arg("skip_regions") = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ large_elements_limit: Whether to elide elements attributes above this
+ number of elements. Defaults to None (no limit).
+ large_resource_limit: Whether to elide resource attributes above this
+ number of characters. Defaults to None (no limit). If large_elements_limit
+ is set and this is None, the behavior will be to use large_elements_limit
+ as large_resource_limit.
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable). Defaults to False.
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
+ prefixes for the SSA identifiers. Defaults to False.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ skip_regions: Whether to skip printing regions. Defaults to False.)")
+ .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
+ nb::arg("desired_version") = nb::none(),
+ R"(
+ Write the bytecode form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to.
+ desired_version: Optional version of bytecode to emit.
+ Returns:
+ The bytecode writer status.)")
+ .def("get_asm", &PyOperationBase::getAsm,
+ // Careful: Lots of arguments must match up with get_asm method.
+ nb::arg("binary") = false,
+ nb::arg("large_elements_limit") = nb::none(),
+ nb::arg("large_resource_limit") = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
+ R"(
+ Gets the assembly form of the operation with all options available.
+
+ Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+ Returns:
+ Either a bytes or str object, depending on the setting of the `binary`
+ argument.)")
+ .def("verify", &PyOperationBase::verify,
+ "Verify the operation. Raises MLIRError if verification fails, and "
+ "returns True otherwise.")
+ .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
+ "Puts self immediately after the other operation in its parent "
+ "block.")
+ .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
+ "Puts self immediately before the other operation in its parent "
+ "block.")
+ .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
+ nb::arg("other"),
+ R"(
+ Checks if this operation is before another in the same block.
+
+ Args:
+ other: Another operation in the same parent block.
+
+ Returns:
+ True if this operation is before `other` in the operation list of the parent block.)")
+ .def(
+ "clone",
+ [](PyOperationBase &self,
+ const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperation().clone(ip);
+ },
+ nb::arg("ip") = nb::none(),
+ R"(
+ Creates a deep copy of the operation.
+
+ Args:
+ ip: Optional insertion point where the cloned operation should be inserted.
+ If None, the current insertion point is used. If False, the operation
+ remains detached.
+
+ Returns:
+ A new Operation that is a clone of this operation.)")
+ .def(
+ "detach_from_parent",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ if (!operation.isAttached())
+ throw nb::value_error("Detached operation has no parent.");
+
+ operation.detachFromParent();
+ return operation.createOpView();
+ },
+ "Detaches the operation from its parent block.")
+ .def_prop_ro(
+ "attached",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ return operation.isAttached();
+ },
+ "Reports if the operation is attached to its parent block.")
+ .def(
+ "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
+ R"(
+ Erases the operation and frees its memory.
+
+ Note:
+ After erasing, any Python references to the operation become invalid.)")
+ .def("walk", &PyOperationBase::walk, nb::arg("callback"),
+ nb::arg("walk_order") = MlirWalkPostOrder,
+ // nb::sig replaces the generated signature, so the default for
+ // `walk_order` must be spelled out here to match nb::arg above.
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = WalkOrder.POST_ORDER) -> None"),
+ // clang-format on
+ R"(
+ Walks the operation tree with a callback function.
+
+ Args:
+ callback: A callable that takes an Operation and returns a WalkResult.
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER). Defaults to
+ POST_ORDER.)");
+
+ // Python `Operation`: the concrete operation class, layered on the shared
+ // `_OperationBase` accessors.
+ nb::class_<PyOperation, PyOperationBase>(m, "Operation")
+ .def_static(
+ "create",
+ [](std::string_view name,
+ std::optional<std::vector<PyType *>> results,
+ std::optional<std::vector<PyValue *>> operands,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors, int regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp,
+ bool inferType) -> nb::typed<nb::object, PyOperation> {
+ // Unpack/validate operands; a null entry is reported as a None
+ // operand rather than forwarded to the C API.
+ llvm::SmallVector<MlirValue, 4> mlirOperands;
+ if (operands) {
+ mlirOperands.reserve(operands->size());
+ for (PyValue *operand : *operands) {
+ if (!operand)
+ throw nb::value_error("operand value cannot be None");
+ mlirOperands.push_back(operand->get());
+ }
+ }
+
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ return PyOperation::create(name, results, mlirOperands, attributes,
+ successors, regions, pyLoc, maybeIp,
+ inferType);
+ },
+ nb::arg("name"), nb::arg("results") = nb::none(),
+ nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
+ nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
+ nb::arg("infer_type") = false,
+ R"(
+ Creates a new operation.
+
+ Args:
+ name: Operation name (e.g. `dialect.operation`).
+ results: Optional sequence of Type representing op result types.
+ operands: Optional operands of the operation.
+ attributes: Optional Dict of {str: Attribute}.
+ successors: Optional List of Block for the operation's successors.
+ regions: Number of regions to create (default = 0).
+ loc: Optional Location object (defaults to resolve from context manager).
+ ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
+ infer_type: Whether to infer result types (default = False).
+ Returns:
+ A new Operation object. It is inserted at `ip` (or the context manager's insertion point) when one applies; otherwise it is detached. Detached operations can be added to blocks, which causes them to become attached.)")
+ .def_static(
+ "parse",
+ [](const std::string &sourceStr, const std::string &sourceName,
+ DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyOpView> {
+ return PyOperation::parse(context->getRef(), sourceStr, sourceName)
+ ->createOpView();
+ },
+ nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
+ nb::arg("context") = nb::none(),
+ "Parses an operation. Supports both text assembly format and binary "
+ "bytecode format.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
+ "Gets a capsule wrapping the MlirOperation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyOperation::createFromCapsule,
+ "Creates an Operation from a capsule wrapping MlirOperation.")
+ .def_prop_ro(
+ "operation",
+ [](nb::object self) -> nb::typed<nb::object, PyOperation> {
+ return self;
+ },
+ "Returns self (the operation).")
+ .def_prop_ro(
+ "opview",
+ [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
+ return self.createOpView();
+ },
+ R"(
+ Returns an OpView of this operation.
+
+ Note:
+ If the operation has a registered and loaded dialect then this OpView will
+ be concrete wrapper class.)")
+ .def_prop_ro("block", &PyOperation::getBlock,
+ "Returns the block containing this operation.")
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &self) {
+ return PyOpSuccessors(self.getOperation().getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def("_set_invalid", &PyOperation::setInvalid,
+ "Invalidate the operation.");
+
+ // Python `OpView`: base class for generated op wrappers. Its generic
+ // constructor, `build_generic`, and `parse` forward to PyOpView helpers.
+ auto opViewClass =
+ nb::class_<PyOpView, PyOperationBase>(m, "OpView")
+ .def(nb::init<nb::typed<nb::object, PyOperation>>(),
+ nb::arg("operation"))
+ .def(
+ "__init__",
+ [](PyOpView *self, std::string_view name,
+ std::tuple<int, bool> opRegionSpec,
+ nb::object operandSegmentSpecObj,
+ nb::object resultSegmentSpecObj,
+ std::optional<nb::list> resultTypeList,
+ std::optional<nb::list> operandList,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp) {
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ // `operands` defaults to None; treat it as an empty operand list
+ // instead of failing the conversion to nb::list.
+ nb::list operands = operandList ? *operandList : nb::list();
+ // nanobind hands over uninitialized storage; construct in place.
+ new (self) PyOpView(PyOpView::buildGeneric(
+ name, opRegionSpec, operandSegmentSpecObj,
+ resultSegmentSpecObj, resultTypeList, operands,
+ attributes, successors, regions, pyLoc, maybeIp));
+ },
+ nb::arg("name"), nb::arg("opRegionSpec"),
+ nb::arg("operandSegmentSpecObj") = nb::none(),
+ nb::arg("resultSegmentSpecObj") = nb::none(),
+ nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
+ nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(),
+ nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
+ nb::arg("ip") = nb::none())
+ .def_prop_ro(
+ "operation",
+ [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperationObject();
+ })
+ .def_prop_ro("opview",
+ [](nb::object self) -> nb::typed<nb::object, PyOpView> {
+ return self;
+ })
+ .def(
+ "__str__",
+ [](PyOpView &self) { return nb::str(self.getOperationObject()); })
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &self) {
+ return PyOpSuccessors(self.getOperation().getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def(
+ "_set_invalid",
+ [](PyOpView &self) { self.getOperation().setInvalid(); },
+ "Invalidate the operation.");
+ opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
+ opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
+ opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
+ // It is faster to pass the operation_name, ods_regions, and
+ // ods_operand_segments/ods_result_segments as arguments to the constructor,
+ // rather than to access them as attributes.
+ opViewClass.attr("build_generic") = classmethod(
+ [](nb::handle cls, std::optional<nb::list> resultTypeList,
+ std::optional<nb::list> operandList, std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions, std::optional<PyLocation> location,
+ const nb::object &maybeIp) {
+ std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ std::tuple<int, bool> opRegionSpec =
+ nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+ nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
+ nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ // Same None-to-empty-list handling as OpView.__init__.
+ nb::list operands = operandList ? *operandList : nb::list();
+ return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
+ resultSegmentSpec, resultTypeList,
+ operands, attributes, successors,
+ regions, pyLoc, maybeIp);
+ },
+ nb::arg("cls"), nb::arg("results") = nb::none(),
+ nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
+ nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
+ "Builds a specific, generated OpView based on class level attributes.");
+ opViewClass.attr("parse") = classmethod(
+ [](const nb::object &cls, const std::string &sourceStr,
+ const std::string &sourceName,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
+ PyOperationRef parsed =
+ PyOperation::parse(context->getRef(), sourceStr, sourceName);
+
+ // Check if the expected operation was parsed, and cast to the
+ // appropriate `OpView` subclass if successful.
+ // NOTE: This accesses attributes that have been automatically added to
+ // `OpView` subclasses, and is not intended to be used on `OpView`
+ // directly.
+ std::string clsOpName =
+ nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ MlirStringRef identifier =
+ mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
+ std::string_view parsedOpName(identifier.data, identifier.length);
+ if (clsOpName != parsedOpName)
+ throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
+ parsedOpName + "'");
+ return PyOpView::constructDerived(cls, parsed.getObject());
+ },
+ nb::arg("cls"), nb::arg("source"), nb::kw_only(),
+ nb::arg("source_name") = "", nb::arg("context") = nb::none(),
+ "Parses a specific, generated OpView based on class level attributes.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyRegion.
+ //----------------------------------------------------------------------------
+ // Python `Region`: exposes the region's blocks and owning operation, and
+ // compares regions by underlying pointer identity.
+ nb::class_<PyRegion>(m, "Region")
+ .def_prop_ro(
+ "blocks",
+ [](PyRegion &region) {
+ MlirRegion rawRegion = region.get();
+ return PyBlockList(region.getParentOperation(), rawRegion);
+ },
+ "Returns a forward-optimized sequence of blocks.")
+ .def_prop_ro(
+ "owner",
+ [](PyRegion &region) -> nb::typed<nb::object, PyOpView> {
+ // Return the parent operation wrapped as an OpView.
+ return region.getParentOperation()->createOpView();
+ },
+ "Returns the operation owning this region.")
+ .def(
+ "__iter__",
+ [](PyRegion &region) {
+ region.checkValid();
+ // Iteration starts from the region's first block.
+ MlirBlock head = mlirRegionGetFirstBlock(region.get());
+ return PyBlockIterator(region.getParentOperation(), head);
+ },
+ "Iterates over blocks in the region.")
+ .def(
+ "__eq__",
+ [](PyRegion &lhs, PyRegion &rhs) {
+ auto *lhsPtr = lhs.get().ptr;
+ return lhsPtr == rhs.get().ptr;
+ },
+ "Compares two regions for pointer equality.")
+ // Fallback for non-region operands; must stay after the typed overload.
+ .def(
+ "__eq__", [](PyRegion &, nb::object &) { return false; },
+ "Compares region with non-region object (always returns False).");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyBlock.
+ //----------------------------------------------------------------------------
+ // Python `Block`: argument/operation accessors, block creation relative to a
+ // region or sibling block, and CFG successor/predecessor views.
+ nb::class_<PyBlock>(m, "Block")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
+ "Gets a capsule wrapping the MlirBlock.")
+ .def_prop_ro(
+ "owner",
+ [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
+ return self.getParentOperation()->createOpView();
+ },
+ "Returns the owning operation of this block.")
+ .def_prop_ro(
+ "region",
+ [](PyBlock &self) {
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ return PyRegion(self.getParentOperation(), region);
+ },
+ "Returns the owning region of this block.")
+ .def_prop_ro(
+ "arguments",
+ [](PyBlock &self) {
+ return PyBlockArgumentList(self.getParentOperation(), self.get());
+ },
+ "Returns a list of block arguments.")
+ .def(
+ "add_argument",
+ [](PyBlock &self, const PyType &type, const PyLocation &loc) {
+ return PyBlockArgument(self.getParentOperation(),
+ mlirBlockAddArgument(self.get(), type, loc));
+ },
+ "type"_a, "loc"_a,
+ R"(
+ Appends an argument of the specified type to the block.
+
+ Args:
+ type: The type of the argument to add.
+ loc: The source location for the argument.
+
+ Returns:
+ The newly added block argument.)")
+ .def(
+ "erase_argument",
+ [](PyBlock &self, unsigned index) {
+ return mlirBlockEraseArgument(self.get(), index);
+ },
+ nb::arg("index"),
+ R"(
+ Erases the argument at the specified index.
+
+ Args:
+ index: The index of the argument to erase.)")
+ .def_prop_ro(
+ "operations",
+ [](PyBlock &self) {
+ return PyOperationList(self.getParentOperation(), self.get());
+ },
+ "Returns a forward-optimized sequence of operations.")
+ .def_static(
+ "create_at_start",
+ [](PyRegion &parent, const nb::sequence &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ parent.checkValid();
+ MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
+ // Position 0 places the new block first in the region.
+ mlirRegionInsertOwnedBlock(parent, 0, block);
+ return PyBlock(parent.getParentOperation(), block);
+ },
+ nb::arg("parent"), nb::arg("arg_types") = nb::list(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block at the beginning of the given "
+ "region (with given argument types and locations).")
+ .def(
+ "append_to",
+ [](PyBlock &self, PyRegion &region) {
+ // Validate both handles before touching the C API, as the sibling
+ // create_* methods do.
+ self.checkValid();
+ region.checkValid();
+ MlirBlock b = self.get();
+ // Detach from the current owning region (if any) so the block can
+ // be transferred to `region`.
+ if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
+ mlirBlockDetach(b);
+ mlirRegionAppendOwnedBlock(region.get(), b);
+ },
+ nb::arg("region"),
+ R"(
+ Appends this block to a region.
+
+ Transfers ownership if the block is currently owned by another region.
+
+ Args:
+ region: The region to append the block to.)")
+ .def(
+ "create_before",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block before this block "
+ "(with given argument types and locations).")
+ .def(
+ "create_after",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block after this block "
+ "(with given argument types and locations).")
+ .def(
+ "__iter__",
+ [](PyBlock &self) {
+ self.checkValid();
+ MlirOperation firstOperation =
+ mlirBlockGetFirstOperation(self.get());
+ return PyOperationIterator(self.getParentOperation(),
+ firstOperation);
+ },
+ "Iterates over operations in the block.")
+ .def(
+ "__eq__",
+ [](PyBlock &self, PyBlock &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two blocks for pointer equality.")
+ // Fallback for non-block operands; must stay after the typed overload.
+ .def(
+ "__eq__", [](PyBlock &self, nb::object &other) { return false; },
+ "Compares block with non-block object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyBlock &self) {
+ // Hash the same pointer that __eq__ compares.
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the block.")
+ .def(
+ "__str__",
+ [](PyBlock &self) {
+ self.checkValid();
+ PyPrintAccumulator printAccum;
+ mlirBlockPrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the block.")
+ .def(
+ "append",
+ [](PyBlock &self, PyOperationBase &operation) {
+ if (operation.getOperation().isAttached())
+ operation.getOperation().detachFromParent();
+
+ MlirOperation mlirOperation = operation.getOperation().get();
+ mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
+ // Record the new parent on the Python-side operation wrapper.
+ operation.getOperation().setAttached(
+ self.getParentOperation().getObject());
+ },
+ nb::arg("operation"),
+ R"(
+ Appends an operation to this block.
+
+ If the operation is currently in another block, it will be moved.
+
+ Args:
+ operation: The operation to append to the block.)")
+ .def_prop_ro(
+ "successors",
+ [](PyBlock &self) {
+ return PyBlockSuccessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block successors.")
+ .def_prop_ro(
+ "predecessors",
+ [](PyBlock &self) {
+ return PyBlockPredecessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block predecessors.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyInsertionPoint.
+ //----------------------------------------------------------------------------
+
+ // Python `InsertionPoint`: a position inside a block, usable as a context
+ // manager; `InsertionPoint.current` reports the one bound to this thread.
+ nb::class_<PyInsertionPoint>(m, "InsertionPoint")
+ .def(nb::init<PyBlock &>(), nb::arg("block"),
+ "Inserts after the last operation but still inside the block.")
+ .def("__enter__", &PyInsertionPoint::contextEnter,
+ "Enters the insertion point as a context manager.")
+ .def("__exit__", &PyInsertionPoint::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none(),
+ "Exits the insertion point context manager.")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) {
+ auto *active = PyThreadContextEntry::getDefaultInsertionPoint();
+ // No insertion point bound on this thread: report it to Python.
+ if (active == nullptr)
+ throw nb::value_error("No current InsertionPoint");
+ return active;
+ },
+ nb::sig("def current(/) -> InsertionPoint"),
+ "Gets the InsertionPoint bound to the current thread or raises "
+ "ValueError if none has been set.")
+ // Second constructor overload; kept after the PyBlock overload above.
+ .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
+ "Inserts before a referenced operation.")
+ .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
+ nb::arg("block"),
+ R"(
+ Creates an insertion point at the beginning of a block.
+
+ Args:
+ block: The block at whose beginning operations should be inserted.
+
+ Returns:
+ An InsertionPoint at the block's beginning.)")
+ .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
+ nb::arg("block"),
+ R"(
+ Creates an insertion point before a block's terminator.
+
+ Args:
+ block: The block whose terminator to insert before.
+
+ Returns:
+ An InsertionPoint before the terminator.
+
+ Raises:
+ ValueError: If the block has no terminator.)")
+ .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+ R"(
+ Creates an insertion point immediately after an operation.
+
+ Args:
+ operation: The operation after which to insert.
+
+ Returns:
+ An InsertionPoint after the operation.)")
+ .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
+ R"(
+ Inserts an operation at this insertion point.
+
+ Args:
+ operation: The operation to insert.)")
+ .def_prop_ro(
+ "block", [](PyInsertionPoint &ip) { return ip.getBlock(); },
+ "Returns the block that this `InsertionPoint` points to.")
+ .def_prop_ro(
+ "ref_operation",
+ [](PyInsertionPoint &ip)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto anchor = ip.getRefOperation();
+ // No anchor operation means "insert at the end of the block".
+ if (!anchor)
+ return std::nullopt;
+ return anchor->getObject();
+ },
+ "The reference operation before which new operations are "
+ "inserted, or None if the insertion point is at the end of "
+ "the block.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyAttribute.
+ //----------------------------------------------------------------------------
+ nb::class_<PyAttribute>(m, "Attribute")
+ // Delegate to the PyAttribute copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirAttribute.
+ .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
+ "Casts the passed attribute to the generic `Attribute`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
+ "Gets a capsule wrapping the MlirAttribute.")
+ .def_static(
+ MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
+ "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
+ .def_static(
+ "parse",
+ [](const std::string &attrSpec, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyAttribute> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr = mlirAttributeParseGet(
+ context->get(), toMlirStringRef(attrSpec));
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Unable to parse attribute", errors.take());
+ return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ "Parses an attribute from an assembly form. Raises an `MLIRError` on "
+ "failure.")
+ .def_prop_ro(
+ "context",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Attribute`.")
+ .def_prop_ro(
+ "type",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirAttributeGetType(self))
+ .maybeDownCast();
+ },
+ "Returns the type of the `Attribute`.")
+ .def(
+ "get_named",
+ [](PyAttribute &self, std::string name) {
+ return PyNamedAttribute(self, std::move(name));
+ },
+ nb::keep_alive<0, 1>(),
+ R"(
+ Binds a name to the attribute, creating a `NamedAttribute`.
+
+ Args:
+ name: The name to bind to the `Attribute`.
+
+ Returns:
+ A `NamedAttribute` with the given name and this attribute.)")
+ .def(
+ "__eq__",
+ [](PyAttribute &self, PyAttribute &other) { return self == other; },
+ "Compares two attributes for equality.")
+ .def(
+ "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
+ "Compares attribute with non-attribute object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyAttribute &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the attribute.")
+ .def(
+ "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
+ kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyAttribute &self) {
+ PyPrintAccumulator printAccum;
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the Attribute.")
+ .def(
+ "__repr__",
+ [](PyAttribute &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, attribute values are generally considered useful and
+ // are printed. This may need to be re-evaluated if debug dumps end
+ // up being excessive.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Attribute(");
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the attribute.")
+ .def_prop_ro(
+ "typeid",
+ [](PyAttribute &self) {
+ MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ return PyTypeID(mlirTypeID);
+ },
+ "Returns the `TypeID` of the attribute.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the attribute to a more specific attribute if possible.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyNamedAttribute
+ //----------------------------------------------------------------------------
+ nb::class_<PyNamedAttribute>(m, "NamedAttribute")
+ .def(
+ "__repr__",
+ [](PyNamedAttribute &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("NamedAttribute(");
+ printAccum.parts.append(
+ nb::str(mlirIdentifierStr(self.namedAttr.name).data,
+ mlirIdentifierStr(self.namedAttr.name).length));
+ printAccum.parts.append("=");
+ mlirAttributePrint(self.namedAttr.attribute,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the named attribute.")
+ .def_prop_ro(
+ "name",
+ [](PyNamedAttribute &self) {
+ return mlirIdentifierStr(self.namedAttr.name);
+ },
+ "The name of the `NamedAttribute` binding.")
+ .def_prop_ro(
+ "attr",
+ [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
+ nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
+ "The underlying generic attribute of the `NamedAttribute` binding.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyType.
+ //----------------------------------------------------------------------------
+ nb::class_<PyType>(m, "Type")
+ // Delegate to the PyType copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirType.
+ .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
+ "Casts the passed type to the generic `Type`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
+ "Gets a capsule wrapping the `MlirType`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
+ "Creates a Type from a capsule wrapping `MlirType`.")
+ .def_static(
+ "parse",
+ [](std::string typeSpec,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type =
+ mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Unable to parse type", errors.take());
+ return PyType(context.get()->getRef(), type).maybeDownCast();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ R"(
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
+
+ See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
+ .def_prop_ro(
+ "context",
+ [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Type`.")
+ .def(
+ "__eq__", [](PyType &self, PyType &other) { return self == other; },
+ "Compares two types for equality.")
+ .def(
+ "__eq__", [](PyType &self, nb::object &other) { return false; },
+ nb::arg("other").none(),
+ "Compares type with non-type object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyType &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the `Type`.")
+ .def(
+ "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyType &self) {
+ PyPrintAccumulator printAccum;
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the `Type`.")
+ .def(
+ "__repr__",
+ [](PyType &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, types are an exception as they typically have compact
+ // assembly forms and printing them is useful.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Type(");
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the `Type`.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) -> nb::typed<nb::object, PyType> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the Type to a more specific `Type` if possible.")
+ .def_prop_ro(
+ "typeid",
+ [](PyType &self) {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ if (!mlirTypeIDIsNull(mlirTypeID))
+ return PyTypeID(mlirTypeID);
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
+ throw nb::value_error(
+ (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
+ },
+ "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
+ "`Type` has no "
+ "`TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyTypeID.
+ //----------------------------------------------------------------------------
+ nb::class_<PyTypeID>(m, "TypeID")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
+ "Gets a capsule wrapping the `MlirTypeID`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
+ "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
+ // Note, this tests whether the underlying TypeIDs are the same,
+ // not whether the wrapper MlirTypeIDs are the same, nor whether
+ // the Python objects are the same (i.e., PyTypeID is a value type).
+ .def(
+ "__eq__",
+ [](PyTypeID &self, PyTypeID &other) { return self == other; },
+ "Compares two `TypeID`s for equality.")
+ .def(
+ "__eq__",
+ [](PyTypeID &self, const nb::object &other) { return false; },
+ "Compares TypeID with non-TypeID object (always returns False).")
+ // Note, this gives the hash value of the underlying TypeID, not the
+ // hash value of the Python object, nor the hash value of the
+ // MlirTypeID wrapper.
+ .def(
+ "__hash__",
+ [](PyTypeID &self) {
+ return static_cast<size_t>(mlirTypeIDHashValue(self));
+ },
+ "Returns the hash value of the `TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Value.
+ //----------------------------------------------------------------------------
+ m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
+
+ nb::class_<PyValue>(m, "Value", nb::is_generic(),
+ nb::sig("class Value(Generic[_T])"))
+ .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
+ "Creates a Value reference from another `Value`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
+ "Gets a capsule wrapping the `MlirValue`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
+ "Creates a `Value` from a capsule wrapping `MlirValue`.")
+ .def_prop_ro(
+ "context",
+ [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getParentOperation()->getContext().getObject();
+ },
+ "Context in which the value lives.")
+ .def(
+ "dump", [](PyValue &self) { mlirValueDump(self.get()); },
+ kDumpDocstring)
+ .def_prop_ro(
+ "owner",
+ [](PyValue &self) -> nb::object {
+ MlirValue v = self.get();
+ if (mlirValueIsAOpResult(v)) {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match "
+ "that in "
+ "the IR");
+ return self.getParentOperation().getObject();
+ }
+
+ if (mlirValueIsABlockArgument(v)) {
+ MlirBlock block = mlirBlockArgumentGetOwner(self.get());
+ return nb::cast(PyBlock(self.getParentOperation(), block));
+ }
+
+ assert(false && "Value must be a block argument or an op result");
+ return nb::none();
+ },
+ "Returns the owner of the value (`Operation` for results, `Block` "
+ "for "
+ "arguments).")
+ .def_prop_ro(
+ "uses",
+ [](PyValue &self) {
+ return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
+ },
+ "Returns an iterator over uses of this value.")
+ .def(
+ "__eq__",
+ [](PyValue &self, PyValue &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two values for pointer equality.")
+ .def(
+ "__eq__", [](PyValue &self, nb::object other) { return false; },
+ "Compares value with non-value object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyValue &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the value.")
+ .def(
+ "__str__",
+ [](PyValue &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Value(");
+ mlirValuePrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ R"(
+ Returns the string form of the value.
+
+ If the value is a block argument, this is the assembly form of its type and the
+ position in the argument list. If the value is an operation result, this is
+ equivalent to printing the operation that produced it.
+ )")
+ .def(
+ "get_name",
+ [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
+ PyPrintAccumulator printAccum;
+ MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+ if (useLocalScope)
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ if (useNameLocAsPrefix)
+ mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
+ MlirAsmState valueState =
+ mlirAsmStateCreateForValue(self.get(), flags);
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ mlirOpPrintingFlagsDestroy(flags);
+ mlirAsmStateDestroy(valueState);
+ return printAccum.join();
+ },
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ R"(
+ Returns the string form of value as an operand.
+
+ Args:
+ use_local_scope: Whether to use local scope for naming.
+ use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
+
+ Returns:
+ The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
+ .def(
+ "get_name",
+ [](PyValue &self, PyAsmState &state) {
+ PyPrintAccumulator printAccum;
+ MlirAsmState valueState = state.get();
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ nb::arg("state"),
+ "Returns the string form of value as an operand (i.e., the ValueID).")
+ .def_prop_ro(
+ "type",
+ [](PyValue &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()))
+ .maybeDownCast();
+ },
+ "Returns the type of the value.")
+ .def(
+ "set_type",
+ [](PyValue &self, const PyType &type) {
+ mlirValueSetType(self.get(), type);
+ },
+ nb::arg("type"), "Sets the type of the value.",
+ nb::sig("def set_type(self, type: _T)"))
+ .def(
+ "replace_all_uses_with",
+ [](PyValue &self, PyValue &with) {
+ mlirValueReplaceAllUsesOfWith(self.get(), with.get());
+ },
+ "Replace all uses of value with the new value, updating anything in "
+ "the IR that uses `self` to use the other value instead.")
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, const nb::list &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (nb::handle exception : exceptions) {
+ exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
+ }
+
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with,
+ std::vector<PyOperation> &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (PyOperation &exception : exceptions)
+ exceptionOps.push_back(exception);
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the `Value` to a more specific kind if possible.")
+ .def_prop_ro(
+ "location",
+ [](MlirValue self) {
+ return PyLocation(
+ PyMlirContext::forContext(mlirValueGetContext(self)),
+ mlirValueGetLocation(self));
+ },
+ "Returns the source location of the value.");
+
+ PyBlockArgument::bind(m);
+ PyOpResult::bind(m);
+ PyOpOperand::bind(m);
+
+ nb::class_<PyAsmState>(m, "AsmState")
+ .def(nb::init<PyValue &, bool>(), nb::arg("value"),
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an `AsmState` for consistent SSA value naming.
+
+ Args:
+ value: The value to create state for.
+ use_local_scope: Whether to use local scope for naming.)")
+ .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an AsmState for consistent SSA value naming.
+
+ Args:
+ op: The operation to create state for.
+ use_local_scope: Whether to use local scope for naming.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of SymbolTable.
+ //----------------------------------------------------------------------------
+ nb::class_<PySymbolTable>(m, "SymbolTable")
+ .def(nb::init<PyOperationBase &>(),
+ R"(
+ Creates a symbol table for an operation.
+
+ Args:
+ operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
+
+ Raises:
+ TypeError: If the operation is not a symbol table.)")
+ .def(
+ "__getitem__",
+ [](PySymbolTable &self,
+ const std::string &name) -> nb::typed<nb::object, PyOpView> {
+ return self.dunderGetItem(name);
+ },
+ R"(
+ Looks up a symbol by name in the symbol table.
+
+ Args:
+ name: The name of the symbol to look up.
+
+ Returns:
+ The operation defining the symbol.
+
+ Raises:
+ KeyError: If the symbol is not found.)")
+ .def("insert", &PySymbolTable::insert, nb::arg("operation"),
+ R"(
+ Inserts a symbol operation into the symbol table.
+
+ Args:
+ operation: An operation with a symbol name to insert.
+
+ Returns:
+ The symbol name attribute of the inserted operation.
+
+ Raises:
+ ValueError: If the operation does not have a symbol name.)")
+ .def("erase", &PySymbolTable::erase, nb::arg("operation"),
+ R"(
+ Erases a symbol operation from the symbol table.
+
+ Args:
+ operation: The symbol operation to erase.
+
+ Note:
+ The operation is also erased from the IR and invalidated.)")
+ .def("__delitem__", &PySymbolTable::dunderDel,
+ "Deletes a symbol by name from the symbol table.")
+ .def(
+ "__contains__",
+ [](PySymbolTable &table, const std::string &name) {
+ return !mlirOperationIsNull(mlirSymbolTableLookup(
+ table, mlirStringRefCreate(name.data(), name.length())));
+ },
+ "Checks if a symbol with the given name exists in the table.")
+ // Static helpers.
+ .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
+ nb::arg("symbol"), nb::arg("name"),
+ "Sets the symbol name for a symbol operation.")
+ .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
+ nb::arg("symbol"),
+ "Gets the symbol name from a symbol operation.")
+ .def_static("get_visibility", &PySymbolTable::getVisibility,
+ nb::arg("symbol"),
+ "Gets the visibility attribute of a symbol operation.")
+ .def_static("set_visibility", &PySymbolTable::setVisibility,
+ nb::arg("symbol"), nb::arg("visibility"),
+ "Sets the visibility attribute of a symbol operation.")
+ .def_static("replace_all_symbol_uses",
+ &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
+ nb::arg("new_symbol"), nb::arg("from_op"),
+ "Replaces all uses of a symbol with a new symbol name within "
+ "the given operation.")
+ .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
+ nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
+ nb::arg("callback"),
+ "Walks symbol tables starting from an operation with a "
+ "callback function.");
+
+ // Container bindings.
+ PyBlockArgumentList::bind(m);
+ PyBlockIterator::bind(m);
+ PyBlockList::bind(m);
+ PyBlockSuccessors::bind(m);
+ PyBlockPredecessors::bind(m);
+ PyOperationIterator::bind(m);
+ PyOperationList::bind(m);
+ PyOpAttributeMap::bind(m);
+ PyOpOperandIterator::bind(m);
+ PyOpOperandList::bind(m);
+ PyOpResultList::bind(m);
+ PyOpSuccessors::bind(m);
+ PyRegionIterator::bind(m);
+ PyRegionList::bind(m);
+
+ // Debug bindings.
+ PyGlobalDebugFlag::bind(m);
+
+ // Attribute builder getter.
+ PyAttrBuilderMap::bind(m);
+
+ // nb::register_exception_translator([](const std::exception_ptr &p,
+ // void *payload) {
+ // // We can't define exceptions with custom fields through pybind, so
+ // instead
+ // // the exception class is defined in python and imported here.
+ // try {
+ // if (p)
+ // std::rethrow_exception(p);
+ // } catch (const MLIRError &e) {
+ // nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ // .attr("MLIRError")(e.message, e.errorDiagnostics);
+ // PyErr_SetObject(PyExc_Exception, obj.ptr());
+ // }
+ // });
+}
+
+// Forward declarations for binding entry points implemented in other
+// translation units. populateIRAttributes is defined in IRAttributes.cpp and
+// populateIRTypes in IRTypes.cpp, both of which this series adds to the _mlir
+// extension's SOURCES. populateIRAffine / populateIRInterfaces presumably live
+// in IRAffine.cpp / IRInterfaces.cpp the same way -- confirm.
+namespace mlir::python {
+void populateIRAffine(nb::module_ &m);
+void populateIRAttributes(nb::module_ &m);
+void populateIRInterfaces(nb::module_ &m);
+void populateIRTypes(nb::module_ &m);
+// Defined in IRCore.cpp, which is compiled into the MLIRPythonSupport shared
+// library; it registers a nanobind exception translator from inside that
+// library.
+void registerMLIRErrorInIRCore();
+} // namespace mlir::python
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
@@ -158,4 +2415,18 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
+ registerMLIRErrorInIRCore();
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 572afa902746d..b1fd48cf410af 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,8 +8,8 @@
#include "Pass.h"
-#include "Globals.h"
-#include "IRModule.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/Pass.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..0221bd10e723e 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
#define MLIR_BINDINGS_PYTHON_PASS_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0f0ed22c50fa9..6bf2b5dbdae3d 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,7 +8,7 @@
#include "Rewrite.h"
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index ae89e2b9589f1..f8ffdc7bdc458 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 02124d12aea40..024d52dc97987 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -533,6 +533,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
MainModule.cpp
+ IRAffine.cpp
+ IRAttributes.cpp
+ IRInterfaces.cpp
+ IRTypes.cpp
Pass.cpp
Rewrite.cpp
@@ -1007,12 +1011,8 @@ get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso
list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
add_mlir_library(MLIRPythonSupport
- ${PYTHON_SOURCE_DIR}/Globals.cpp
- ${PYTHON_SOURCE_DIR}/IRAffine.cpp
- ${PYTHON_SOURCE_DIR}/IRAttributes.cpp
${PYTHON_SOURCE_DIR}/IRCore.cpp
- ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
- ${PYTHON_SOURCE_DIR}/IRTypes.cpp
+ ${PYTHON_SOURCE_DIR}/Globals.cpp
EXCLUDE_FROM_LIBMLIR
SHARED
LINK_COMPONENTS
@@ -1030,6 +1030,13 @@ set_target_properties(MLIRPythonSupport PROPERTIES
RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
)
+# Emit the nanobind library target (the first link dependency of the _mlir
+# extension, see NB_LIBRARY_TARGET_NAME above) into the same _mlir_libs
+# directory as MLIRPythonSupport.
+# Note: CMake has no BINARY_OUTPUT_DIRECTORY target property (setting it only
+# creates an inert custom property). Shared libraries use
+# LIBRARY_OUTPUT_DIRECTORY on ELF/Mach-O; on Windows the DLL follows
+# RUNTIME_OUTPUT_DIRECTORY and the import library ARCHIVE_OUTPUT_DIRECTORY.
+set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
+  LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+  # Needed for windows (and doesn't hurt others).
+  RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+  ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+)
set(eh_rtti_enable)
if(MSVC)
set(eh_rtti_enable /EHsc /GR)
@@ -1051,4 +1058,4 @@ endif()
+# The _mlir extension links the shared MLIRPythonSupport library publicly.
target_link_libraries(
MLIRPythonModules.extension._mlir.dso
PUBLIC MLIRPythonSupport)
-
+# Compile MLIRPythonSupport with NB_DOMAIN set from
+# MLIR_BINDINGS_PYTHON_NB_DOMAIN. NOTE(review): presumably this must match the
+# nanobind domain the extension modules are built with so both share one type
+# registry -- confirm.
+target_compile_definitions(MLIRPythonSupport PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
>From 1aa48c3fc1ff79582188094c90c06f66e4915e7c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 13:46:11 -0800
Subject: [PATCH 3/3] works
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 -
mlir/include/mlir/Bindings/Python/IRCore.h | 19 +++++++++++
mlir/lib/Bindings/Python/IRAttributes.cpp | 14 +-------
mlir/lib/Bindings/Python/IRCore.cpp | 5 +--
mlir/lib/Bindings/Python/IRTypes.cpp | 18 ++---------
mlir/lib/Bindings/Python/MainModule.cpp | 32 ++-----------------
mlir/lib/Bindings/Python/Pass.cpp | 3 +-
mlir/python/CMakeLists.txt | 7 ++--
mlir/test/python/dialects/python_test.py | 2 +-
.../python/lib/PythonTestModuleNanobind.cpp | 31 +++++++++++-------
10 files changed, 54 insertions(+), 78 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index b725a00c183ca..a4cbbb19203d3 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -609,7 +609,6 @@ function(add_mlir_python_common_capi_library name)
# Generate the aggregate .so that everything depends on.
add_mlir_aggregate(${name}
SHARED
- DISABLE_INSTALL
EMBED_LIBS ${_embed_libs}
)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 488196ea42e44..66a6272eaaf68 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1325,6 +1325,25 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
+/// Installs a nanobind exception translator that turns a C++ `MLIRError`
+/// escaping a binding into an instance of the Python-side `ir.MLIRError`
+/// class, constructed from the error message and the captured diagnostics.
+/// nanobind cannot declare exception types carrying custom fields, which is
+/// why the Python class is fetched by importing the `ir` module here.
+/// Exceptions of any other type are not caught and escape this translator
+/// (nanobind then consults its other translators).
+inline void registerMLIRError() {
+  auto translateMLIRError = [](const std::exception_ptr &pending,
+                               void * /*payload*/) {
+    if (!pending)
+      return;
+    try {
+      std::rethrow_exception(pending);
+    } catch (const MLIRError &error) {
+      nanobind::module_ irModule =
+          nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+      nanobind::object pyError =
+          irModule.attr("MLIRError")(error.message, error.errorDiagnostics);
+      PyErr_SetObject(PyExc_Exception, pyError.ptr());
+    }
+  };
+  nanobind::register_exception_translator(translateMLIRError);
+}
+
+/// Defined in IRCore.cpp (compiled into the MLIRPythonSupport library);
+/// registers a nanobind exception translator from within that library.
+void registerMLIRErrorInCore();
+
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 36367e658697c..4323374a5d5b7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1852,18 +1852,6 @@ void populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
+ registerMLIRError();
}
} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 88cffb64906d7..ea1e62b8165ad 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -920,9 +920,6 @@ nb::object PyOperation::create(std::string_view name,
PyMlirContext::ErrorCapture errors(location.getContext());
MlirOperation operation = mlirOperationCreate(&state);
if (!operation.ptr) {
- for (auto take : errors.take()) {
- std::cout << take.message << "\n";
- }
throw MLIRError("Operation creation failed", errors.take());
}
PyOperationRef created =
@@ -1672,7 +1669,7 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-void registerMLIRErrorInIRCore() {
+void registerMLIRErrorInCore() {
nb::register_exception_translator([](const std::exception_ptr &p,
void *payload) {
// We can't define exceptions with custom fields through pybind, so
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 294ab91a059e2..7d9a0f16c913a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -13,10 +13,10 @@
#include <optional>
-#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
@@ -1175,18 +1175,6 @@ void populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
-});
-}
+ registerMLIRError();
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 686c55ee1e6a8..643851fcaf046 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -2250,21 +2250,6 @@ static void populateIRCore(nb::module_ &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
-
- // nb::register_exception_translator([](const std::exception_ptr &p,
- // void *payload) {
- // // We can't define exceptions with custom fields through pybind, so
- // instead
- // // the exception class is defined in python and imported here.
- // try {
- // if (p)
- // std::rethrow_exception(p);
- // } catch (const MLIRError &e) {
- // nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- // .attr("MLIRError")(e.message, e.errorDiagnostics);
- // PyErr_SetObject(PyExc_Exception, obj.ptr());
- // }
- // });
}
namespace mlir::python {
@@ -2272,7 +2257,6 @@ void populateIRAffine(nb::module_ &m);
void populateIRAttributes(nb::module_ &m);
void populateIRInterfaces(nb::module_ &m);
void populateIRTypes(nb::module_ &m);
-void registerMLIRErrorInIRCore();
} // namespace mlir::python
// -----------------------------------------------------------------------------
@@ -2415,18 +2399,6 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
- registerMLIRErrorInIRCore();
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
+ registerMLIRError();
+ registerMLIRErrorInCore();
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index b1fd48cf410af..3cfdfe49b4e3e 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,9 +8,9 @@
#include "Pass.h"
+#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir-c/Pass.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
@@ -254,4 +254,5 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
},
"Print the textual representation for this PassManager, suitable to "
"be passed to `parse` for round-tripping.");
+ registerMLIRError();
}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 024d52dc97987..39d61e0376c98 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -1018,10 +1018,10 @@ add_mlir_library(MLIRPythonSupport
LINK_COMPONENTS
Support
LINK_LIBS
+ Python::Module
${NB_LIBRARY_TARGET_NAME}
- MLIRCAPIIR
+ MLIRPythonCAPI
)
-target_link_libraries(MLIRPythonSupport PUBLIC ${NB_LIBRARY_TARGET_NAME})
nanobind_link_options(MLIRPythonSupport)
set_target_properties(MLIRPythonSupport PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
@@ -1058,4 +1058,11 @@ endif()
target_link_libraries(
MLIRPythonModules.extension._mlir.dso
PUBLIC MLIRPythonSupport)
+# The Python test extension target may not exist (e.g. when MLIR tests are
+# disabled); linking to a missing target is a hard configure error.
+if(TARGET MLIRPythonModules.extension._mlirPythonTestNanobind.dso)
+  target_link_libraries(
+    MLIRPythonModules.extension._mlirPythonTestNanobind.dso
+    PUBLIC MLIRPythonSupport)
+endif()
target_compile_definitions(MLIRPythonSupport PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index f0f74ebc12155..8c28bb3a0aa4d 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -646,7 +646,7 @@ def testCustomType():
try:
TestType(42)
except TypeError as e:
- assert "Expected an MLIR object (got 42)" in str(e)
+ assert "incompatible function arguments" in str(e)
except ValueError as e:
assert "Cannot cast type to TestType (from 42)" in str(e)
else:
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a497fcccf13d7..e53f1ab3b4d3f 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -14,6 +14,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -26,6 +27,24 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
+/// Python class `TestType`, bound through mlir::python::PyConcreteType.
+struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirPythonTestTestTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestType";
+  using PyConcreteType::PyConcreteType;
+
+  // Exposes static `TestType.get(context=None)`; None is accepted for context.
+  static void bindDerived(ClassTy &c) {
+    auto getFn = [](mlir::python::DefaultingPyMlirContext context) {
+      MlirType type = mlirPythonTestTestTypeGet(context->get());
+      return PyTestType(context->getRef(), type);
+    };
+    c.def_static("get", getFn, nb::arg("context").none() = nb::none());
+  }
+};
+
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
@@ -78,17 +97,7 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
// clang-format on
nb::arg("cls"), nb::arg("context").none() = nb::none());
- mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
- mlirPythonTestTestTypeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPythonTestTestTypeGet(ctx));
- },
- // clang-format off
- nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("context").none() = nb::none());
+ PyTestType::bind(m);
auto typeCls =
mlir_type_subclass(m, "TestIntegerRankedTensorType",
More information about the Mlir-commits mailing list