[Mlir-commits] [mlir] 0e4be26 - [mlir][Python] fix dialect extensions which bind C types (#175405)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 10 21:24:59 PST 2026
Author: Maksim Levental
Date: 2026-01-10T21:24:55-08:00
New Revision: 0e4be262f4d0e74462b3f3d75e638e4ba3c56a4f
URL: https://github.com/llvm/llvm-project/commit/0e4be262f4d0e74462b3f3d75e638e4ba3c56a4f
DIFF: https://github.com/llvm/llvm-project/commit/0e4be262f4d0e74462b3f3d75e638e4ba3c56a4f.diff
LOG: [mlir][Python] fix dialect extensions which bind C types (#175405)
Fix some dialect bindings that I missed in https://github.com/llvm/llvm-project/pull/174156 so that they no longer bind C structs directly (binding C structs leads to duplicate type registration when multiple packages are used simultaneously).
Added:
Modified:
mlir/lib/Bindings/Python/DialectIRDL.cpp
mlir/lib/Bindings/Python/DialectLLVM.cpp
mlir/lib/Bindings/Python/DialectLinalg.cpp
mlir/lib/Bindings/Python/DialectQuant.cpp
mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/Bindings/Python/TransformInterpreter.cpp
mlir/test/python/lib/PythonTestModuleNanobind.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp
index 08bcab97c03ec..85567d9986e3a 100644
--- a/mlir/lib/Bindings/Python/DialectIRDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp
@@ -9,19 +9,19 @@
#include "mlir-c/Dialect/IRDL.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using namespace mlir::python::nanobind_adaptors;
static void populateDialectIRDLSubmodule(nb::module_ &m) {
m.def(
"load_dialects",
- [](MlirModule module) {
- if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
+ [](PyModule &module) {
+ if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module.get())))
throw std::runtime_error(
"failed to load IRDL dialects from the input module");
},
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 0e4d81ce41b44..0c579cf261eca 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -43,8 +43,8 @@ struct StructType : PyConcreteType<StructType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get_literal",
- [](const std::vector<PyType> &elements, bool packed, MlirLocation loc,
- DefaultingPyMlirContext context) {
+ [](const std::vector<PyType> &elements, bool packed,
+ DefaultingPyLocation loc, DefaultingPyMlirContext context) {
python::CollectDiagnosticsToStringScope scope(
mlirLocationGetContext(loc));
std::vector<MlirType> elements_(elements.size());
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 0b079b404d42d..299961e100786 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -1,4 +1,4 @@
-//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
+//===- DialectLinalg.cpp - Nanobind module for Linalg dialect API support -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,16 +8,44 @@
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace linalg {
+
+struct PyLinalgContractionDimensions : MlirLinalgContractionDimensions {
+ PyLinalgContractionDimensions(const MlirLinalgContractionDimensions &dims) {
+ batch = dims.batch;
+ m = dims.m;
+ n = dims.n;
+ k = dims.k;
+ }
+};
+
+struct PyLinalgConvolutionDimensions : MlirLinalgConvolutionDimensions {
+ PyLinalgConvolutionDimensions(const MlirLinalgConvolutionDimensions &dims) {
+ batch = dims.batch;
+ outputImage = dims.outputImage;
+ outputChannel = dims.outputChannel;
+ filterLoop = dims.filterLoop;
+ inputChannel = dims.inputChannel;
+ depth = dims.depth;
+ strides = dims.strides;
+ dilations = dims.dilations;
+ }
+};
-static std::optional<MlirLinalgContractionDimensions>
-InferContractionDimensions(MlirOperation op) {
+static std::optional<PyLinalgContractionDimensions>
+InferContractionDimensions(PyOperationBase &op) {
MlirLinalgContractionDimensions dims =
- mlirLinalgInferContractionDimensions(op);
+ mlirLinalgInferContractionDimensions(op.getOperation());
// Detect "empty" result. This occurs when `op` is not a contraction op,
// or when `linalg::inferContractionDims` fails.
@@ -28,10 +56,10 @@ InferContractionDimensions(MlirOperation op) {
return dims;
}
-static std::optional<MlirLinalgConvolutionDimensions>
-InferConvolutionDimensions(MlirOperation op) {
+static std::optional<PyLinalgConvolutionDimensions>
+InferConvolutionDimensions(PyOperationBase &op) {
MlirLinalgConvolutionDimensions dims =
- mlirLinalgInferConvolutionDimensions(op);
+ mlirLinalgInferConvolutionDimensions(op.getOperation());
// Detect "empty" result. This occurs when `op` is not a convolution op,
// or when `linalg::inferConvolutionDims` fails.
@@ -51,27 +79,30 @@ InferConvolutionDimensions(MlirOperation op) {
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"fill_builtin_region",
- [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
+ [](PyOperationBase &op) {
+ mlirLinalgFillBuiltinNamedOpRegion(op.getOperation());
+ },
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
- m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
- "Checks if the given operation is a Linalg contraction operation.",
- nb::arg("op"));
+ m.def(
+ "isa_contraction_op",
+ [](PyOperationBase &op) {
+ return mlirLinalgIsAContractionOp(op.getOperation());
+ },
+ "Checks if the given operation is a Linalg contraction operation.",
+ nb::arg("op"));
- nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
- .def_prop_ro("batch",
- [](const MlirLinalgContractionDimensions &self) {
- return self.batch;
- })
+ nb::class_<PyLinalgContractionDimensions>(m, "ContractionDimensions")
+ .def_prop_ro(
+ "batch",
+ [](const PyLinalgContractionDimensions &self) { return self.batch; })
.def_prop_ro(
- "m",
- [](const MlirLinalgContractionDimensions &self) { return self.m; })
+ "m", [](const PyLinalgContractionDimensions &self) { return self.m; })
.def_prop_ro(
- "n",
- [](const MlirLinalgContractionDimensions &self) { return self.n; })
- .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+ "n", [](const PyLinalgContractionDimensions &self) { return self.n; })
+ .def_prop_ro("k", [](const PyLinalgContractionDimensions &self) {
return self.k;
});
@@ -82,14 +113,17 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"infer_contraction_dimensions_from_maps",
- [](std::vector<MlirAffineMap> indexingMaps)
- -> std::optional<MlirLinalgContractionDimensions> {
+ [](std::vector<PyAffineMap> indexingMaps)
+ -> std::optional<PyLinalgContractionDimensions> {
if (indexingMaps.empty())
return std::nullopt;
+ std::vector<MlirAffineMap> indexingMaps_(indexingMaps.size());
+ std::copy(indexingMaps.begin(), indexingMaps.end(),
+ indexingMaps_.begin());
MlirLinalgContractionDimensions dims =
- mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
- indexingMaps.size());
+ mlirLinalgInferContractionDimensionsFromMaps(indexingMaps_.data(),
+ indexingMaps_.size());
// Detect "empty" result from invalid input or failed inference.
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
@@ -102,60 +136,67 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"maps.",
nb::arg("indexing_maps"));
- m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
- "Checks if the given operation is a Linalg convolution operation.",
- nb::arg("op"));
+ m.def(
+ "isa_convolution_op",
+ [](PyOperationBase &op) {
+ return mlirLinalgIsAConvolutionOp(op.getOperation());
+ },
+ "Checks if the given operation is a Linalg convolution operation.",
+ nb::arg("op"));
- nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
- .def_prop_ro("batch",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.batch;
- })
+ nb::class_<PyLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
+ .def_prop_ro(
+ "batch",
+ [](const PyLinalgConvolutionDimensions &self) { return self.batch; })
.def_prop_ro("output_image",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.outputImage;
})
.def_prop_ro("output_channel",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.outputChannel;
})
.def_prop_ro("filter_loop",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.filterLoop;
})
.def_prop_ro("input_channel",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.inputChannel;
})
- .def_prop_ro("depth",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.depth;
- })
+ .def_prop_ro(
+ "depth",
+ [](const PyLinalgConvolutionDimensions &self) { return self.depth; })
.def_prop_ro("strides",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.strides;
})
- .def_prop_ro("dilations",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.dilations;
- });
+ .def_prop_ro("dilations", [](const PyLinalgConvolutionDimensions &self) {
+ return self.dilations;
+ });
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
"Infers convolution dimensions", nb::arg("op"));
m.def(
"get_indexing_maps",
- [](MlirOperation op) -> std::optional<MlirAttribute> {
- MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
+ [](PyOperationBase &op) -> std::optional<PyArrayAttribute> {
+ MlirAttribute attr =
+ mlirLinalgGetIndexingMapsAttribute(op.getOperation());
if (mlirAttributeIsNull(attr))
return std::nullopt;
- return attr;
+ return PyArrayAttribute(op.getOperation().getContext(), attr);
},
"Returns the indexing_maps attribute for a linalg op.");
}
+} // namespace linalg
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsLinalg, m) {
m.doc() = "MLIR Linalg dialect.";
- populateDialectLinalgSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::linalg::
+ populateDialectLinalgSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 0cf0e767e82be..fba6a264ff007 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -381,8 +381,8 @@ struct UniformQuantizedSubChannelType
c.def_static(
"get",
[](unsigned flags, const PyType &storageType,
- const PyType &expressedType, MlirAttribute scales,
- MlirAttribute zeroPoints, std::vector<int32_t> quantizedDimensions,
+ const PyType &expressedType, PyAttribute scales,
+ PyAttribute zeroPoints, std::vector<int32_t> quantizedDimensions,
std::vector<int64_t> blockSizes, int64_t storageTypeMin,
int64_t storageTypeMax, DefaultingPyMlirContext context) {
return UniformQuantizedSubChannelType(
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index be0785b126eaa..01b7930deffd2 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -7,14 +7,16 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/ExecutionEngine.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace execution_engine {
/// Owning Wrapper around an ExecutionEngine.
class PyExecutionEngine {
@@ -61,26 +63,31 @@ class PyExecutionEngine {
std::vector<nb::object> referencedObjects;
};
-} // namespace
+} // namespace execution_engine
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
/// Create the `mlir.execution_engine` module here.
NB_MODULE(_mlirExecutionEngine, m) {
m.doc() = "MLIR Execution Engine";
+ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+ using namespace execution_engine;
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
nb::class_<PyExecutionEngine>(m, "ExecutionEngine")
.def(
"__init__",
- [](PyExecutionEngine &self, MlirModule module, int optLevel,
+ [](PyExecutionEngine &self, PyModule &module, int optLevel,
const std::vector<std::string> &sharedLibPaths,
bool enableObjectDump, bool enablePIC) {
llvm::SmallVector<MlirStringRef, 4> libPaths;
for (const std::string &path : sharedLibPaths)
libPaths.push_back({path.c_str(), path.length()});
MlirExecutionEngine executionEngine = mlirExecutionEngineCreate(
- module, optLevel, libPaths.size(), libPaths.data(),
+ module.get(), optLevel, libPaths.size(), libPaths.data(),
enableObjectDump, enablePIC);
if (mlirExecutionEngineIsNull(executionEngine))
throw std::runtime_error(
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a544648a47f45..f04f0b6271630 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2766,7 +2766,7 @@ void populateRoot(nb::module_ &m) {
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ [](PyTypeID mlirTypeID, bool replace) -> nb::object {
return nb::cpp_function([mlirTypeID, replace](
nb::callable typeCaster) -> nb::object {
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
@@ -2781,7 +2781,7 @@ void populateRoot(nb::module_ &m) {
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ [](PyTypeID mlirTypeID, bool replace) -> nb::object {
return nb::cpp_function(
[mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
@@ -3231,7 +3231,7 @@ void populateIRCore(nb::module_ &m) {
"Returns True if this location is a FileLineColLoc.")
.def_prop_ro(
"filename",
- [](MlirLocation loc) {
+ [](PyLocation loc) {
return mlirIdentifierStr(
mlirLocationFileLineColRangeGetFilename(loc));
},
@@ -3295,7 +3295,7 @@ void populateIRCore(nb::module_ &m) {
"Returns True if this location is a `NameLoc`.")
.def_prop_ro(
"name_str",
- [](MlirLocation loc) {
+ [](PyLocation loc) {
return mlirIdentifierStr(mlirLocationNameGetName(loc));
},
"Gets the name string from a `NameLoc`.")
@@ -4698,7 +4698,7 @@ void populateIRCore(nb::module_ &m) {
"Downcasts the `Value` to a more specific kind if possible.")
.def_prop_ro(
"location",
- [](MlirValue self) {
+ [](PyValue self) {
return PyLocation(
PyMlirContext::forContext(mlirValueGetContext(self)),
mlirValueGetLocation(self));
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index af07dd53c2b54..5a0ab0559d74a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -453,8 +453,8 @@ void PyVectorType::bindDerived(ClassTy &c) {
nb::arg("scalable_dims") = nb::none(),
nb::arg("context") = nb::none(), "Create a vector type")
.def_prop_ro("scalable",
- [](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_prop_ro("scalable_dims", [](MlirType self) {
+ [](PyType self) { return mlirVectorTypeIsScalable(self); })
+ .def_prop_ro("scalable_dims", [](PyType self) {
std::vector<bool> scalableDims;
size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
scalableDims.reserve(rank);
@@ -745,9 +745,11 @@ void PyTupleType::bindDerived(ClassTy &c) {
"Create a tuple type");
c.def_static(
"get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t =
- mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+ [](std::vector<PyType> elements, DefaultingPyMlirContext context) {
+ std::vector<MlirType> elements_(elements.size());
+ std::copy(elements.begin(), elements.end(), elements_.begin());
+ MlirType t = mlirTupleTypeGet(context->get(), elements_.size(),
+ elements_.data());
return PyTupleType(context->getRef(), t);
},
nb::arg("elements"), nb::arg("context") = nb::none(),
@@ -793,11 +795,15 @@ void PyFunctionType::bindDerived(ClassTy &c) {
"Gets a FunctionType from a list of input and result types");
c.def_static(
"get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+ [](std::vector<PyType> inputs, std::vector<PyType> results,
DefaultingPyMlirContext context) {
+ std::vector<MlirType> inputs_(inputs.size());
+ std::copy(inputs.begin(), inputs.end(), inputs_.begin());
+ std::vector<MlirType> results_(results.size());
+ std::copy(results.begin(), results.end(), results_.begin());
MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
+ mlirFunctionTypeGet(context->get(), inputs_.size(), inputs_.data(),
+ results_.size(), results_.data());
return PyFunctionType(context->getRef(), t);
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 9e1eb37a816f6..a9f204ff9d0a5 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -14,54 +14,63 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-namespace {
-struct PyMlirTransformOptions {
- PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
- PyMlirTransformOptions(PyMlirTransformOptions &&other) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace transform_interpreter {
+struct PyTransformOptions {
+ PyTransformOptions() { options = mlirTransformOptionsCreate(); };
+ PyTransformOptions(PyTransformOptions &&other) {
options = other.options;
other.options.ptr = nullptr;
}
- PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
+ PyTransformOptions(const PyTransformOptions &) = delete;
- ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
+ ~PyTransformOptions() { mlirTransformOptionsDestroy(options); }
MlirTransformOptions options;
};
-} // namespace
+} // namespace transform_interpreter
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
static void populateTransformInterpreterSubmodule(nb::module_ &m) {
- nb::class_<PyMlirTransformOptions>(m, "TransformOptions")
+ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+ using namespace transform_interpreter;
+ nb::class_<PyTransformOptions>(m, "TransformOptions")
.def(nb::init<>())
.def_prop_rw(
"expensive_checks",
- [](const PyMlirTransformOptions &self) {
+ [](const PyTransformOptions &self) {
return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
},
- [](PyMlirTransformOptions &self, bool value) {
+ [](PyTransformOptions &self, bool value) {
mlirTransformOptionsEnableExpensiveChecks(self.options, value);
})
.def_prop_rw(
"enforce_single_top_level_transform_op",
- [](const PyMlirTransformOptions &self) {
+ [](const PyTransformOptions &self) {
return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
self.options);
},
- [](PyMlirTransformOptions &self, bool value) {
+ [](PyTransformOptions &self, bool value) {
mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
value);
});
m.def(
"apply_named_sequence",
- [](MlirOperation payloadRoot, MlirOperation transformRoot,
- MlirOperation transformModule, const PyMlirTransformOptions &options) {
+ [](PyOperationBase &payloadRoot, PyOperationBase &transformRoot,
+ PyOperationBase &transformModule, const PyTransformOptions &options) {
mlir::python::CollectDiagnosticsToStringScope scope(
- mlirOperationGetContext(transformRoot));
+ mlirOperationGetContext(transformRoot.getOperation()));
// Calling back into Python to invalidate everything under the payload
// root. This is awkward, but we don't have access to PyMlirContext
@@ -69,7 +78,8 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
nb::object obj = nb::cast(payloadRoot);
MlirLogicalResult result = mlirTransformApplyNamedSequence(
- payloadRoot, transformRoot, transformModule, options.options);
+ payloadRoot.getOperation(), transformRoot.getOperation(),
+ transformModule.getOperation(), options.options);
if (mlirLogicalResultIsSuccess(result)) {
// Even in cases of success, we might have diagnostics to report:
std::string msg;
@@ -89,15 +99,16 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
},
nb::arg("payload_root"), nb::arg("transform_root"),
nb::arg("transform_module"),
- nb::arg("transform_options") = PyMlirTransformOptions());
+ nb::arg("transform_options") = PyTransformOptions());
m.def(
"copy_symbols_and_merge_into",
- [](MlirOperation target, MlirOperation other) {
+ [](PyOperationBase &target, PyOperationBase &other) {
mlir::python::CollectDiagnosticsToStringScope scope(
- mlirOperationGetContext(target));
+ mlirOperationGetContext(target.getOperation()));
- MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
+ MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(
+ target.getOperation(), other.getOperation());
if (mlirLogicalResultIsFailure(result)) {
throw nb::value_error(
("Failed to merge symbols.\nDiagnostic message " +
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a296b5e814b4b..e9754749352b1 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -22,14 +22,16 @@
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
-
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace python_test {
static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
return mlirTypeIsARankedTensor(t) &&
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
-struct PyTestType
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
+struct PyTestType : PyConcreteType<PyTestType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPythonTestTestTypeGetTypeID;
@@ -39,8 +41,7 @@ struct PyTestType
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context) {
+ [](DefaultingPyMlirContext context) {
return PyTestType(context->getRef(),
mlirPythonTestTestTypeGet(context.get()->get()));
},
@@ -49,9 +50,7 @@ struct PyTestType
};
struct PyTestIntegerRankedTensorType
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
- PyTestIntegerRankedTensorType,
- mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+ : PyConcreteType<PyTestIntegerRankedTensorType, PyRankedTensorType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirRankedTensorTypeGetTypeID;
@@ -62,8 +61,7 @@ struct PyTestIntegerRankedTensorType
c.def_static(
"get",
[](std::vector<int64_t> shape, unsigned width,
- mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- ctx) {
+ DefaultingPyMlirContext ctx) {
MlirAttribute encoding = mlirAttributeGetNull();
return PyTestIntegerRankedTensorType(
ctx->getRef(),
@@ -76,9 +74,7 @@ struct PyTestIntegerRankedTensorType
}
};
-struct PyTestTensorValue
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
- PyTestTensorValue> {
+struct PyTestTensorValue : PyConcreteValue<PyTestTensorValue> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsAPythonTestTestTensorValue;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -91,9 +87,7 @@ struct PyTestTensorValue
}
};
-class PyTestAttr
- : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
- PyTestAttr> {
+class PyTestAttr : public PyConcreteAttribute<PyTestAttr> {
public:
static constexpr IsAFunctionTy isaFunction =
mlirAttributeIsAPythonTestTestAttribute;
@@ -105,21 +99,23 @@ class PyTestAttr
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context) {
+ [](DefaultingPyMlirContext context) {
return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
context.get()->get()));
},
nb::arg("context").none() = nb::none());
}
};
+} // namespace python_test
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirPythonTestNanobind, m) {
+ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
m.def(
"register_python_test_dialect",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context,
- bool load) {
+ [](DefaultingPyMlirContext context, bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
mlirDialectHandleRegisterDialect(pythonTestDialect,
@@ -144,14 +140,14 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"test_diagnostics_with_errors_and_notes",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- ctx) {
+ [](DefaultingPyMlirContext ctx) {
mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
throw nb::value_error(handler.takeMessage().c_str());
},
nb::arg("context").none() = nb::none());
+ using namespace python_test;
PyTestAttr::bind(m);
PyTestType::bind(m);
PyTestIntegerRankedTensorType::bind(m);
More information about the Mlir-commits
mailing list