[Mlir-commits] [mlir] [mlir][Python] fix dialect extensions which bind C types (PR #175405)
Maksim Levental
llvmlistbot at llvm.org
Sat Jan 10 20:47:27 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175405
>From 3c47ce6fb2be159b609283d81323f223a78eb1c3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 19:03:17 -0800
Subject: [PATCH 1/4] [mlir][Python] fix linalg dialect extension
---
mlir/lib/Bindings/Python/DialectLinalg.cpp | 131 ++++++++++++---------
1 file changed, 78 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 0b079b404d42d..a88bac4a8d68b 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -1,4 +1,4 @@
-//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
+//===- DialectLinalg.cpp - Nanobind module for Linalg dialect API support -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,16 +8,26 @@
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace linalg {
-static std::optional<MlirLinalgContractionDimensions>
-InferContractionDimensions(MlirOperation op) {
+struct PyLinalgContractionDimensions : MlirLinalgContractionDimensions {};
+
+struct PyLinalgConvolutionDimensions : MlirLinalgConvolutionDimensions {};
+
+static std::optional<PyLinalgContractionDimensions>
+InferContractionDimensions(PyOperationBase &op) {
MlirLinalgContractionDimensions dims =
- mlirLinalgInferContractionDimensions(op);
+ mlirLinalgInferContractionDimensions(op.getOperation());
// Detect "empty" result. This occurs when `op` is not a contraction op,
// or when `linalg::inferContractionDims` fails.
@@ -25,13 +35,13 @@ InferContractionDimensions(MlirOperation op) {
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
- return dims;
+ return PyLinalgContractionDimensions{dims.batch, dims.m, dims.k, dims.n};
}
-static std::optional<MlirLinalgConvolutionDimensions>
-InferConvolutionDimensions(MlirOperation op) {
+static std::optional<PyLinalgConvolutionDimensions>
+InferConvolutionDimensions(PyOperationBase &op) {
MlirLinalgConvolutionDimensions dims =
- mlirLinalgInferConvolutionDimensions(op);
+ mlirLinalgInferConvolutionDimensions(op.getOperation());
// Detect "empty" result. This occurs when `op` is not a convolution op,
// or when `linalg::inferConvolutionDims` fails.
@@ -45,33 +55,38 @@ InferConvolutionDimensions(MlirOperation op) {
return std::nullopt;
}
- return dims;
+ return PyLinalgConvolutionDimensions{
+ dims.batch, dims.outputImage, dims.outputChannel, dims.filterLoop,
+ dims.inputChannel, dims.depth, dims.strides, dims.dilations};
}
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"fill_builtin_region",
- [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
+ [](PyOperationBase &op) {
+ mlirLinalgFillBuiltinNamedOpRegion(op.getOperation());
+ },
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
- m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
- "Checks if the given operation is a Linalg contraction operation.",
- nb::arg("op"));
+ m.def(
+ "isa_contraction_op",
+ [](PyOperationBase &op) {
+ return mlirLinalgIsAContractionOp(op.getOperation());
+ },
+ "Checks if the given operation is a Linalg contraction operation.",
+ nb::arg("op"));
- nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
- .def_prop_ro("batch",
- [](const MlirLinalgContractionDimensions &self) {
- return self.batch;
- })
+ nb::class_<PyLinalgContractionDimensions>(m, "ContractionDimensions")
+ .def_prop_ro(
+ "batch",
+ [](const PyLinalgContractionDimensions &self) { return self.batch; })
.def_prop_ro(
- "m",
- [](const MlirLinalgContractionDimensions &self) { return self.m; })
+ "m", [](const PyLinalgContractionDimensions &self) { return self.m; })
.def_prop_ro(
- "n",
- [](const MlirLinalgContractionDimensions &self) { return self.n; })
- .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+ "n", [](const PyLinalgContractionDimensions &self) { return self.n; })
+ .def_prop_ro("k", [](const PyLinalgContractionDimensions &self) {
return self.k;
});
@@ -82,80 +97,90 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"infer_contraction_dimensions_from_maps",
- [](std::vector<MlirAffineMap> indexingMaps)
- -> std::optional<MlirLinalgContractionDimensions> {
+ [](std::vector<PyAffineMap> indexingMaps)
+ -> std::optional<PyLinalgContractionDimensions> {
if (indexingMaps.empty())
return std::nullopt;
+ std::vector<MlirAffineMap> indexingMaps_(indexingMaps.size());
+ std::copy(indexingMaps.begin(), indexingMaps.end(),
+ indexingMaps_.begin());
MlirLinalgContractionDimensions dims =
- mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
- indexingMaps.size());
+ mlirLinalgInferContractionDimensionsFromMaps(indexingMaps_.data(),
+ indexingMaps_.size());
// Detect "empty" result from invalid input or failed inference.
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
- return dims;
+ return PyLinalgContractionDimensions{dims.batch, dims.m, dims.k,
+ dims.n};
},
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
"maps.",
nb::arg("indexing_maps"));
- m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
- "Checks if the given operation is a Linalg convolution operation.",
- nb::arg("op"));
+ m.def(
+ "isa_convolution_op",
+ [](PyOperationBase &op) {
+ return mlirLinalgIsAConvolutionOp(op.getOperation());
+ },
+ "Checks if the given operation is a Linalg convolution operation.",
+ nb::arg("op"));
- nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
- .def_prop_ro("batch",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.batch;
- })
+ nb::class_<PyLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
+ .def_prop_ro(
+ "batch",
+ [](const PyLinalgConvolutionDimensions &self) { return self.batch; })
.def_prop_ro("output_image",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.outputImage;
})
.def_prop_ro("output_channel",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.outputChannel;
})
.def_prop_ro("filter_loop",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.filterLoop;
})
.def_prop_ro("input_channel",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.inputChannel;
})
- .def_prop_ro("depth",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.depth;
- })
+ .def_prop_ro(
+ "depth",
+ [](const PyLinalgConvolutionDimensions &self) { return self.depth; })
.def_prop_ro("strides",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.strides;
})
- .def_prop_ro("dilations",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.dilations;
- });
+ .def_prop_ro("dilations", [](const PyLinalgConvolutionDimensions &self) {
+ return self.dilations;
+ });
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
"Infers convolution dimensions", nb::arg("op"));
m.def(
"get_indexing_maps",
- [](MlirOperation op) -> std::optional<MlirAttribute> {
- MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
+ [](PyOperationBase &op) -> std::optional<PyAttribute> {
+ MlirAttribute attr =
+ mlirLinalgGetIndexingMapsAttribute(op.getOperation());
if (mlirAttributeIsNull(attr))
return std::nullopt;
- return attr;
+ return PyArrayAttribute(op.getOperation().getContext(), attr);
},
"Returns the indexing_maps attribute for a linalg op.");
}
+} // namespace linalg
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsLinalg, m) {
m.doc() = "MLIR Linalg dialect.";
- populateDialectLinalgSubmodule(m);
+ mlir::python::mlir::linalg::populateDialectLinalgSubmodule(m);
}
>From f18b5f73e435f3caf106afb37bdcb21aa35c1aee Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 19:03:57 -0800
Subject: [PATCH 2/4] [mlir][Python] fix linalg and python_test dialect
extensions
---
mlir/lib/Bindings/Python/DialectLinalg.cpp | 11 ++---
.../python/lib/PythonTestModuleNanobind.cpp | 42 +++++++++----------
2 files changed, 25 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index a88bac4a8d68b..68d355c557729 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -35,7 +35,7 @@ InferContractionDimensions(PyOperationBase &op) {
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
- return PyLinalgContractionDimensions{dims.batch, dims.m, dims.k, dims.n};
+ return PyLinalgContractionDimensions{dims.batch, dims.m, dims.n, dims.k};
}
static std::optional<PyLinalgConvolutionDimensions>
@@ -114,8 +114,8 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
- return PyLinalgContractionDimensions{dims.batch, dims.m, dims.k,
- dims.n};
+ return PyLinalgContractionDimensions{dims.batch, dims.m, dims.n,
+ dims.k};
},
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
"maps.",
@@ -165,7 +165,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"get_indexing_maps",
- [](PyOperationBase &op) -> std::optional<PyAttribute> {
+ [](PyOperationBase &op) -> std::optional<PyArrayAttribute> {
MlirAttribute attr =
mlirLinalgGetIndexingMapsAttribute(op.getOperation());
if (mlirAttributeIsNull(attr))
@@ -182,5 +182,6 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
NB_MODULE(_mlirDialectsLinalg, m) {
m.doc() = "MLIR Linalg dialect.";
- mlir::python::mlir::linalg::populateDialectLinalgSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::linalg::
+ populateDialectLinalgSubmodule(m);
}
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a296b5e814b4b..e9754749352b1 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -22,14 +22,16 @@
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
-
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace python_test {
static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
return mlirTypeIsARankedTensor(t) &&
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
-struct PyTestType
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
+struct PyTestType : PyConcreteType<PyTestType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPythonTestTestTypeGetTypeID;
@@ -39,8 +41,7 @@ struct PyTestType
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context) {
+ [](DefaultingPyMlirContext context) {
return PyTestType(context->getRef(),
mlirPythonTestTestTypeGet(context.get()->get()));
},
@@ -49,9 +50,7 @@ struct PyTestType
};
struct PyTestIntegerRankedTensorType
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
- PyTestIntegerRankedTensorType,
- mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+ : PyConcreteType<PyTestIntegerRankedTensorType, PyRankedTensorType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirRankedTensorTypeGetTypeID;
@@ -62,8 +61,7 @@ struct PyTestIntegerRankedTensorType
c.def_static(
"get",
[](std::vector<int64_t> shape, unsigned width,
- mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- ctx) {
+ DefaultingPyMlirContext ctx) {
MlirAttribute encoding = mlirAttributeGetNull();
return PyTestIntegerRankedTensorType(
ctx->getRef(),
@@ -76,9 +74,7 @@ struct PyTestIntegerRankedTensorType
}
};
-struct PyTestTensorValue
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
- PyTestTensorValue> {
+struct PyTestTensorValue : PyConcreteValue<PyTestTensorValue> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsAPythonTestTestTensorValue;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -91,9 +87,7 @@ struct PyTestTensorValue
}
};
-class PyTestAttr
- : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
- PyTestAttr> {
+class PyTestAttr : public PyConcreteAttribute<PyTestAttr> {
public:
static constexpr IsAFunctionTy isaFunction =
mlirAttributeIsAPythonTestTestAttribute;
@@ -105,21 +99,23 @@ class PyTestAttr
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context) {
+ [](DefaultingPyMlirContext context) {
return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
context.get()->get()));
},
nb::arg("context").none() = nb::none());
}
};
+} // namespace python_test
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirPythonTestNanobind, m) {
+ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
m.def(
"register_python_test_dialect",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context,
- bool load) {
+ [](DefaultingPyMlirContext context, bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
mlirDialectHandleRegisterDialect(pythonTestDialect,
@@ -144,14 +140,14 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"test_diagnostics_with_errors_and_notes",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- ctx) {
+ [](DefaultingPyMlirContext ctx) {
mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
throw nb::value_error(handler.takeMessage().c_str());
},
nb::arg("context").none() = nb::none());
+ using namespace python_test;
PyTestAttr::bind(m);
PyTestType::bind(m);
PyTestIntegerRankedTensorType::bind(m);
>From 3647dd33c8567a0f44cc6fbddfd85259f5ce6e04 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 20:16:27 -0800
Subject: [PATCH 3/4] add ExecutionEngineModule.cpp
---
.../Bindings/Python/ExecutionEngineModule.cpp | 19 +++++++++++++------
1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index be0785b126eaa..01b7930deffd2 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -7,14 +7,16 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/ExecutionEngine.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace execution_engine {
/// Owning Wrapper around an ExecutionEngine.
class PyExecutionEngine {
@@ -61,26 +63,31 @@ class PyExecutionEngine {
std::vector<nb::object> referencedObjects;
};
-} // namespace
+} // namespace execution_engine
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
/// Create the `mlir.execution_engine` module here.
NB_MODULE(_mlirExecutionEngine, m) {
m.doc() = "MLIR Execution Engine";
+ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+ using namespace execution_engine;
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
nb::class_<PyExecutionEngine>(m, "ExecutionEngine")
.def(
"__init__",
- [](PyExecutionEngine &self, MlirModule module, int optLevel,
+ [](PyExecutionEngine &self, PyModule &module, int optLevel,
const std::vector<std::string> &sharedLibPaths,
bool enableObjectDump, bool enablePIC) {
llvm::SmallVector<MlirStringRef, 4> libPaths;
for (const std::string &path : sharedLibPaths)
libPaths.push_back({path.c_str(), path.length()});
MlirExecutionEngine executionEngine = mlirExecutionEngineCreate(
- module, optLevel, libPaths.size(), libPaths.data(),
+ module.get(), optLevel, libPaths.size(), libPaths.data(),
enableObjectDump, enablePIC);
if (mlirExecutionEngineIsNull(executionEngine))
throw std::runtime_error(
>From 3102c201a924c6b819c389755d075ce6c1e539d6 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 20:46:19 -0800
Subject: [PATCH 4/4] fix some more
---
mlir/lib/Bindings/Python/DialectIRDL.cpp | 8 +--
mlir/lib/Bindings/Python/DialectLLVM.cpp | 4 +-
mlir/lib/Bindings/Python/DialectQuant.cpp | 4 +-
mlir/lib/Bindings/Python/IRCore.cpp | 10 ++--
mlir/lib/Bindings/Python/IRTypes.cpp | 22 +++++---
.../Bindings/Python/TransformInterpreter.cpp | 51 +++++++++++--------
6 files changed, 58 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp
index 08bcab97c03ec..85567d9986e3a 100644
--- a/mlir/lib/Bindings/Python/DialectIRDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp
@@ -9,19 +9,19 @@
#include "mlir-c/Dialect/IRDL.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using namespace mlir::python::nanobind_adaptors;
static void populateDialectIRDLSubmodule(nb::module_ &m) {
m.def(
"load_dialects",
- [](MlirModule module) {
- if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
+ [](PyModule &module) {
+ if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module.get())))
throw std::runtime_error(
"failed to load IRDL dialects from the input module");
},
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 0e4d81ce41b44..0c579cf261eca 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -43,8 +43,8 @@ struct StructType : PyConcreteType<StructType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get_literal",
- [](const std::vector<PyType> &elements, bool packed, MlirLocation loc,
- DefaultingPyMlirContext context) {
+ [](const std::vector<PyType> &elements, bool packed,
+ DefaultingPyLocation loc, DefaultingPyMlirContext context) {
python::CollectDiagnosticsToStringScope scope(
mlirLocationGetContext(loc));
std::vector<MlirType> elements_(elements.size());
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 0cf0e767e82be..fba6a264ff007 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -381,8 +381,8 @@ struct UniformQuantizedSubChannelType
c.def_static(
"get",
[](unsigned flags, const PyType &storageType,
- const PyType &expressedType, MlirAttribute scales,
- MlirAttribute zeroPoints, std::vector<int32_t> quantizedDimensions,
+ const PyType &expressedType, PyAttribute scales,
+ PyAttribute zeroPoints, std::vector<int32_t> quantizedDimensions,
std::vector<int64_t> blockSizes, int64_t storageTypeMin,
int64_t storageTypeMax, DefaultingPyMlirContext context) {
return UniformQuantizedSubChannelType(
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a544648a47f45..f04f0b6271630 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2766,7 +2766,7 @@ void populateRoot(nb::module_ &m) {
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ [](PyTypeID mlirTypeID, bool replace) -> nb::object {
return nb::cpp_function([mlirTypeID, replace](
nb::callable typeCaster) -> nb::object {
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
@@ -2781,7 +2781,7 @@ void populateRoot(nb::module_ &m) {
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ [](PyTypeID mlirTypeID, bool replace) -> nb::object {
return nb::cpp_function(
[mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
@@ -3231,7 +3231,7 @@ void populateIRCore(nb::module_ &m) {
"Returns True if this location is a FileLineColLoc.")
.def_prop_ro(
"filename",
- [](MlirLocation loc) {
+ [](PyLocation loc) {
return mlirIdentifierStr(
mlirLocationFileLineColRangeGetFilename(loc));
},
@@ -3295,7 +3295,7 @@ void populateIRCore(nb::module_ &m) {
"Returns True if this location is a `NameLoc`.")
.def_prop_ro(
"name_str",
- [](MlirLocation loc) {
+ [](PyLocation loc) {
return mlirIdentifierStr(mlirLocationNameGetName(loc));
},
"Gets the name string from a `NameLoc`.")
@@ -4698,7 +4698,7 @@ void populateIRCore(nb::module_ &m) {
"Downcasts the `Value` to a more specific kind if possible.")
.def_prop_ro(
"location",
- [](MlirValue self) {
+ [](PyValue self) {
return PyLocation(
PyMlirContext::forContext(mlirValueGetContext(self)),
mlirValueGetLocation(self));
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index af07dd53c2b54..5a0ab0559d74a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -453,8 +453,8 @@ void PyVectorType::bindDerived(ClassTy &c) {
nb::arg("scalable_dims") = nb::none(),
nb::arg("context") = nb::none(), "Create a vector type")
.def_prop_ro("scalable",
- [](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_prop_ro("scalable_dims", [](MlirType self) {
+ [](PyType self) { return mlirVectorTypeIsScalable(self); })
+ .def_prop_ro("scalable_dims", [](PyType self) {
std::vector<bool> scalableDims;
size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
scalableDims.reserve(rank);
@@ -745,9 +745,11 @@ void PyTupleType::bindDerived(ClassTy &c) {
"Create a tuple type");
c.def_static(
"get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t =
- mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+ [](std::vector<PyType> elements, DefaultingPyMlirContext context) {
+ std::vector<MlirType> elements_(elements.size());
+ std::copy(elements.begin(), elements.end(), elements_.begin());
+ MlirType t = mlirTupleTypeGet(context->get(), elements_.size(),
+ elements_.data());
return PyTupleType(context->getRef(), t);
},
nb::arg("elements"), nb::arg("context") = nb::none(),
@@ -793,11 +795,15 @@ void PyFunctionType::bindDerived(ClassTy &c) {
"Gets a FunctionType from a list of input and result types");
c.def_static(
"get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+ [](std::vector<PyType> inputs, std::vector<PyType> results,
DefaultingPyMlirContext context) {
+ std::vector<MlirType> inputs_(inputs.size());
+ std::copy(inputs.begin(), inputs.end(), inputs_.begin());
+ std::vector<MlirType> results_(results.size());
+ std::copy(results.begin(), results.end(), results_.begin());
MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
+ mlirFunctionTypeGet(context->get(), inputs_.size(), inputs_.data(),
+ results_.size(), results_.data());
return PyFunctionType(context->getRef(), t);
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 9e1eb37a816f6..a9f204ff9d0a5 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -14,54 +14,63 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-namespace {
-struct PyMlirTransformOptions {
- PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
- PyMlirTransformOptions(PyMlirTransformOptions &&other) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace transform_interpreter {
+struct PyTransformOptions {
+ PyTransformOptions() { options = mlirTransformOptionsCreate(); };
+ PyTransformOptions(PyTransformOptions &&other) {
options = other.options;
other.options.ptr = nullptr;
}
- PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
+ PyTransformOptions(const PyTransformOptions &) = delete;
- ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
+ ~PyTransformOptions() { mlirTransformOptionsDestroy(options); }
MlirTransformOptions options;
};
-} // namespace
+} // namespace transform_interpreter
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
static void populateTransformInterpreterSubmodule(nb::module_ &m) {
- nb::class_<PyMlirTransformOptions>(m, "TransformOptions")
+ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+ using namespace transform_interpreter;
+ nb::class_<PyTransformOptions>(m, "TransformOptions")
.def(nb::init<>())
.def_prop_rw(
"expensive_checks",
- [](const PyMlirTransformOptions &self) {
+ [](const PyTransformOptions &self) {
return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
},
- [](PyMlirTransformOptions &self, bool value) {
+ [](PyTransformOptions &self, bool value) {
mlirTransformOptionsEnableExpensiveChecks(self.options, value);
})
.def_prop_rw(
"enforce_single_top_level_transform_op",
- [](const PyMlirTransformOptions &self) {
+ [](const PyTransformOptions &self) {
return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
self.options);
},
- [](PyMlirTransformOptions &self, bool value) {
+ [](PyTransformOptions &self, bool value) {
mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
value);
});
m.def(
"apply_named_sequence",
- [](MlirOperation payloadRoot, MlirOperation transformRoot,
- MlirOperation transformModule, const PyMlirTransformOptions &options) {
+ [](PyOperationBase &payloadRoot, PyOperationBase &transformRoot,
+ PyOperationBase &transformModule, const PyTransformOptions &options) {
mlir::python::CollectDiagnosticsToStringScope scope(
- mlirOperationGetContext(transformRoot));
+ mlirOperationGetContext(transformRoot.getOperation()));
// Calling back into Python to invalidate everything under the payload
// root. This is awkward, but we don't have access to PyMlirContext
@@ -69,7 +78,8 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
nb::object obj = nb::cast(payloadRoot);
MlirLogicalResult result = mlirTransformApplyNamedSequence(
- payloadRoot, transformRoot, transformModule, options.options);
+ payloadRoot.getOperation(), transformRoot.getOperation(),
+ transformModule.getOperation(), options.options);
if (mlirLogicalResultIsSuccess(result)) {
// Even in cases of success, we might have diagnostics to report:
std::string msg;
@@ -89,15 +99,16 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
},
nb::arg("payload_root"), nb::arg("transform_root"),
nb::arg("transform_module"),
- nb::arg("transform_options") = PyMlirTransformOptions());
+ nb::arg("transform_options") = PyTransformOptions());
m.def(
"copy_symbols_and_merge_into",
- [](MlirOperation target, MlirOperation other) {
+ [](PyOperationBase &target, PyOperationBase &other) {
mlir::python::CollectDiagnosticsToStringScope scope(
- mlirOperationGetContext(target));
+ mlirOperationGetContext(target.getOperation()));
- MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
+ MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(
+ target.getOperation(), other.getOperation());
if (mlirLogicalResultIsFailure(result)) {
throw nb::value_error(
("Failed to merge symbols.\nDiagnostic message " +
More information about the Mlir-commits mailing list