[llvm-branch-commits] [mlir] [mlir][Python] port in-tree dialect extensions to use MLIRPythonSupport (PR #174156)
Maksim Levental via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Jan 3 20:04:09 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/174156
>From bcb510a239ae586ee8c16a69d81283116192feb1 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 14:20:39 -0800
Subject: [PATCH 1/3] [mlir][Python] move IRTypes and IRAttributes to public
headers
---
mlir/test/python/lib/PythonTestModuleNanobind.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a296b5e814b4b..b229c02ccf5e6 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -156,4 +156,4 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
PyTestType::bind(m);
PyTestIntegerRankedTensorType::bind(m);
PyTestTensorValue::bind(m);
-}
+}
\ No newline at end of file
>From b6af0195e1ab989efcd0c81b3a6ea21fb61eeccd Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 29 Dec 2025 11:14:00 -0800
Subject: [PATCH 2/3] [mlir][Python] port dialect extensions to use core
PyConcreteType, PyConcreteAttribute
---
mlir/lib/Bindings/Python/DialectAMDGPU.cpp | 111 ++-
mlir/lib/Bindings/Python/DialectGPU.cpp | 152 ++--
mlir/lib/Bindings/Python/DialectLLVM.cpp | 297 ++++---
mlir/lib/Bindings/Python/DialectNVGPU.cpp | 50 +-
mlir/lib/Bindings/Python/DialectPDL.cpp | 228 +++--
mlir/lib/Bindings/Python/DialectQuant.cpp | 810 ++++++++++--------
mlir/lib/Bindings/Python/DialectSMT.cpp | 89 +-
.../Bindings/Python/DialectSparseTensor.cpp | 266 +++---
mlir/lib/Bindings/Python/DialectTransform.cpp | 249 +++---
.../dialects/transform/extras/__init__.py | 11 +-
mlir/test/python/dialects/pdl_types.py | 211 ++---
11 files changed, 1439 insertions(+), 1035 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
index 26ffc0e427e41..de24dfa9660c1 100644
--- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -8,58 +8,97 @@
#include "mlir-c/Dialect/AMDGPU.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
- auto amdgpuTDMBaseType =
- mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
- mlirAMDGPUTDMBaseTypeGetTypeID);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace amdgpu {
+struct TDMBaseType : PyConcreteType<TDMBaseType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMBaseTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMBaseType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
- return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
- },
- "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
- nb::arg("element_type"), nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &elementType, DefaultingPyMlirContext context) {
+ return TDMBaseType(
+ context->getRef(),
+ mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType));
+ },
+ "Gets an instance of TDMBaseType in the same context",
+ nb::arg("element_type"), nb::arg("context").none() = nb::none());
+ }
+};
- auto amdgpuTDMDescriptorType = mlir_type_subclass(
- m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
- mlirAMDGPUTDMDescriptorTypeGetTypeID);
+struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAAMDGPUTDMDescriptorType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMDescriptorTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMDescriptorType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMDescriptorType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
- },
- "Gets an instance of TDMDescriptorType in the same context",
- nb::arg("cls"), nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return TDMDescriptorType(
+ context->getRef(),
+ mlirAMDGPUTDMDescriptorTypeGet(context.get()->get()));
+ },
+ "Gets an instance of TDMDescriptorType in the same context",
+ nb::arg("context").none() = nb::none());
+ }
+};
- auto amdgpuTDMGatherBaseType = mlir_type_subclass(
- m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
- mlirAMDGPUTDMGatherBaseTypeGetTypeID);
+struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAAMDGPUTDMGatherBaseType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMGatherBaseTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMGatherBaseType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMGatherBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirType indexType,
- MlirContext ctx) {
- return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
- },
- "Gets an instance of TDMGatherBaseType in the same context",
- nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
- nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &elementType, const PyType &indexType,
+ DefaultingPyMlirContext context) {
+ return TDMGatherBaseType(
+ context->getRef(),
+ mlirAMDGPUTDMGatherBaseTypeGet(context.get()->get(), elementType,
+ indexType));
+ },
+ "Gets an instance of TDMGatherBaseType in the same context",
+ nb::arg("element_type"), nb::arg("index_type"),
+ nb::arg("context").none() = nb::none());
+ }
};
+static void populateDialectAMDGPUSubmodule(nb::module_ &m) {
+ TDMBaseType::bind(m);
+ TDMDescriptorType::bind(m);
+ TDMGatherBaseType::bind(m);
+}
+} // namespace amdgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
NB_MODULE(_mlirDialectsAMDGPU, m) {
m.doc() = "MLIR AMDGPU dialect.";
- populateDialectAMDGPUSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::amdgpu::
+ populateDialectAMDGPUSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 2568d535edb5a..3ea8edec7b136 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,83 +9,105 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace gpu {
// -----------------------------------------------------------------------------
-// Module initialization.
+// AsyncTokenType
// -----------------------------------------------------------------------------
-NB_MODULE(_mlirDialectsGPU, m) {
- m.doc() = "MLIR GPU Dialect";
- //===-------------------------------------------------------------------===//
- // AsyncTokenType
- //===-------------------------------------------------------------------===//
+struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
+ static constexpr const char *pyClassName = "AsyncTokenType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return AsyncTokenType(context->getRef(),
+ mlirGPUAsyncTokenTypeGet(context.get()->get()));
+ },
+ "Gets an instance of AsyncTokenType in the same context",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// ObjectAttr
+//===-------------------------------------------------------------------===//
+
+struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr;
+ static constexpr const char *pyClassName = "ObjectAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
- auto mlirGPUAsyncTokenType =
- mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](MlirAttribute target, uint32_t format, const nb::bytes &object,
+ std::optional<MlirAttribute> mlirObjectProps,
+ std::optional<MlirAttribute> mlirKernelsAttr,
+ DefaultingPyMlirContext context) {
+ MlirStringRef objectStrRef = mlirStringRefCreate(
+ static_cast<char *>(const_cast<void *>(object.data())),
+ object.size());
+ return ObjectAttr(
+ context->getRef(),
+ mlirGPUObjectAttrGetWithKernels(
+ mlirAttributeGetContext(target), target, format, objectStrRef,
+ mlirObjectProps.has_value() ? *mlirObjectProps
+ : MlirAttribute{nullptr},
+ mlirKernelsAttr.has_value() ? *mlirKernelsAttr
+ : MlirAttribute{nullptr}));
+ },
+ "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(),
+ "kernels"_a = nb::none(), "context"_a = nb::none(),
+ "Gets a gpu.object from parameters.");
- mlirGPUAsyncTokenType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirGPUAsyncTokenTypeGet(ctx));
- },
- "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
- nb::arg("ctx") = nb::none());
+ c.def_prop_ro("target", [](MlirAttribute self) {
+ return mlirGPUObjectAttrGetTarget(self);
+ });
+ c.def_prop_ro("format", [](MlirAttribute self) {
+ return mlirGPUObjectAttrGetFormat(self);
+ });
+ c.def_prop_ro("object", [](MlirAttribute self) {
+ MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+ return nb::bytes(stringRef.data, stringRef.length);
+ });
+ c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
+ if (mlirGPUObjectAttrHasProperties(self))
+ return nb::cast(mlirGPUObjectAttrGetProperties(self));
+ return nb::none();
+ });
+ c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
+ if (mlirGPUObjectAttrHasKernels(self))
+ return nb::cast(mlirGPUObjectAttrGetKernels(self));
+ return nb::none();
+ });
+ }
+};
+} // namespace gpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
- //===-------------------------------------------------------------------===//
- // ObjectAttr
- //===-------------------------------------------------------------------===//
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+NB_MODULE(_mlirDialectsGPU, m) {
+ m.doc() = "MLIR GPU Dialect";
- mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirAttribute target, uint32_t format,
- const nb::bytes &object,
- std::optional<MlirAttribute> mlirObjectProps,
- std::optional<MlirAttribute> mlirKernelsAttr) {
- MlirStringRef objectStrRef = mlirStringRefCreate(
- static_cast<char *>(const_cast<void *>(object.data())),
- object.size());
- return cls(mlirGPUObjectAttrGetWithKernels(
- mlirAttributeGetContext(target), target, format, objectStrRef,
- mlirObjectProps.has_value() ? *mlirObjectProps
- : MlirAttribute{nullptr},
- mlirKernelsAttr.has_value() ? *mlirKernelsAttr
- : MlirAttribute{nullptr}));
- },
- "cls"_a, "target"_a, "format"_a, "object"_a,
- "properties"_a = nb::none(), "kernels"_a = nb::none(),
- "Gets a gpu.object from parameters.")
- .def_property_readonly(
- "target",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
- .def_property_readonly(
- "format",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
- .def_property_readonly(
- "object",
- [](MlirAttribute self) {
- MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
- return nb::bytes(stringRef.data, stringRef.length);
- })
- .def_property_readonly("properties",
- [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasProperties(self))
- return nb::cast(
- mlirGPUObjectAttrGetProperties(self));
- return nb::none();
- })
- .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasKernels(self))
- return nb::cast(mlirGPUObjectAttrGetKernels(self));
- return nb::none();
- });
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::AsyncTokenType::bind(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::ObjectAttr::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 05681cecf82b3..d4eb078c0f55c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -13,149 +13,176 @@
#include "mlir-c/Support.h"
#include "mlir-c/Target/LLVMIR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
using namespace llvm;
using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
- //===--------------------------------------------------------------------===//
- // StructType
- //===--------------------------------------------------------------------===//
-
- auto llvmStructType = mlir_type_subclass(
- m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
- llvmStructType
- .def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none())
- .def_classmethod(
- "get_literal_unchecked",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirContext context) {
- CollectDiagnosticsToStringScope scope(context);
-
- MlirType type = mlirLLVMStructTypeLiteralGet(
- context, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "context"_a = nb::none());
-
- llvmStructType.def_classmethod(
- "get_identified",
- [](const nb::object &cls, const std::string &name, MlirContext context) {
- return cls(mlirLLVMStructTypeIdentifiedGet(
- context, mlirStringRefCreate(name.data(), name.size())));
- },
- "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm {
+//===--------------------------------------------------------------------===//
+// StructType
+//===--------------------------------------------------------------------===//
+
+struct StructType : PyConcreteType<StructType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirLLVMStructTypeGetTypeID;
+ static constexpr const char *pyClassName = "StructType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_literal",
+ [](const std::vector<MlirType> &elements, bool packed, MlirLocation loc,
+ DefaultingPyMlirContext context) {
+ python::CollectDiagnosticsToStringScope scope(
+ mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return StructType(context->getRef(), type);
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none(),
+ "context"_a = nb::none());
+
+ c.def_static(
+ "get_literal_unchecked",
+ [](const std::vector<MlirType> &elements, bool packed,
+ DefaultingPyMlirContext context) {
+ python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+ MlirType type = mlirLLVMStructTypeLiteralGet(
+ context.get()->get(), elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return StructType(context->getRef(), type);
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
+
+ c.def_static(
+ "get_identified",
+ [](const std::string &name, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeIdentifiedGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.size())));
+ },
+ "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+ c.def_static(
+ "get_opaque",
+ [](const std::string &name, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeOpaqueGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.size())));
+ },
+ "name"_a, "context"_a = nb::none());
+
+ c.def(
+ "set_body",
+ [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+ MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+ self, elements.size(), elements.data(), packed);
+ if (!mlirLogicalResultIsSuccess(result)) {
+ throw nb::value_error(
+ "Struct body already set to different content.");
+ }
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false);
+
+ c.def_static(
+ "new_identified",
+ [](const std::string &name, const std::vector<MlirType> &elements,
+ bool packed, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeIdentifiedNewGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.length()),
+ elements.size(), elements.data(), packed));
+ },
+ "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
+
+ c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
+ if (mlirLLVMStructTypeIsLiteral(type))
+ return std::nullopt;
+
+ MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+ return StringRef(stringRef.data, stringRef.length).str();
+ });
+
+ c.def_prop_ro("body", [](PyType type) -> nb::object {
+ // Don't crash in absence of a body.
+ if (mlirLLVMStructTypeIsOpaque(type))
+ return nb::none();
+
+ nb::list body;
+ for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+ i < e; ++i) {
+ body.append(mlirLLVMStructTypeGetElementType(type, i));
+ }
+ return body;
+ });
+
+ c.def_prop_ro("packed",
+ [](PyType type) { return mlirLLVMStructTypeIsPacked(type); });
+
+ c.def_prop_ro("opaque",
+ [](PyType type) { return mlirLLVMStructTypeIsOpaque(type); });
+ }
+};
+
+//===--------------------------------------------------------------------===//
+// PointerType
+//===--------------------------------------------------------------------===//
+
+struct PointerType : PyConcreteType<PointerType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMPointerType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirLLVMPointerTypeGetTypeID;
+ static constexpr const char *pyClassName = "PointerType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](std::optional<unsigned> addressSpace,
+ DefaultingPyMlirContext context) {
+ python::CollectDiagnosticsToStringScope scope(context.get()->get());
+ MlirType type = mlirLLVMPointerTypeGet(
+ context.get()->get(),
+ addressSpace.has_value() ? *addressSpace : 0);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return PointerType(context->getRef(), type);
+ },
+ "address_space"_a = nb::none(), nb::kw_only(),
+ "context"_a = nb::none());
+ c.def_prop_ro("address_space", [](PyType type) {
+ return mlirLLVMPointerTypeGetAddressSpace(type);
+ });
+ }
+};
- llvmStructType.def_classmethod(
- "get_opaque",
- [](const nb::object &cls, const std::string &name, MlirContext context) {
- return cls(mlirLLVMStructTypeOpaqueGet(
- context, mlirStringRefCreate(name.data(), name.size())));
- },
- "cls"_a, "name"_a, "context"_a = nb::none());
-
- llvmStructType.def(
- "set_body",
- [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
- MlirLogicalResult result = mlirLLVMStructTypeSetBody(
- self, elements.size(), elements.data(), packed);
- if (!mlirLogicalResultIsSuccess(result)) {
- throw nb::value_error(
- "Struct body already set to different content.");
- }
- },
- "elements"_a, nb::kw_only(), "packed"_a = false);
-
- llvmStructType.def_classmethod(
- "new_identified",
- [](const nb::object &cls, const std::string &name,
- const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
- return cls(mlirLLVMStructTypeIdentifiedNewGet(
- ctx, mlirStringRefCreate(name.data(), name.length()),
- elements.size(), elements.data(), packed));
- },
- "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "context"_a = nb::none());
-
- llvmStructType.def_property_readonly(
- "name", [](MlirType type) -> std::optional<std::string> {
- if (mlirLLVMStructTypeIsLiteral(type))
- return std::nullopt;
-
- MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
- return StringRef(stringRef.data, stringRef.length).str();
- });
-
- llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
- // Don't crash in absence of a body.
- if (mlirLLVMStructTypeIsOpaque(type))
- return nb::none();
-
- nb::list body;
- for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
- ++i) {
- body.append(mlirLLVMStructTypeGetElementType(type, i));
- }
- return body;
- });
-
- llvmStructType.def_property_readonly(
- "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
-
- llvmStructType.def_property_readonly(
- "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
-
- //===--------------------------------------------------------------------===//
- // PointerType
- //===--------------------------------------------------------------------===//
-
- mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType,
- mlirLLVMPointerTypeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, std::optional<unsigned> addressSpace,
- MlirContext context) {
- CollectDiagnosticsToStringScope scope(context);
- MlirType type = mlirLLVMPointerTypeGet(
- context, addressSpace.has_value() ? *addressSpace : 0);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "address_space"_a = nb::none(), nb::kw_only(),
- "context"_a = nb::none())
- .def_property_readonly("address_space", [](MlirType type) {
- return mlirLLVMPointerTypeGetAddressSpace(type);
- });
+static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
+ StructType::bind(m);
+ PointerType::bind(m);
m.def(
"translate_module_to_llvmir",
@@ -167,9 +194,13 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
// clang-format on
"module"_a, nb::rv_policy::take_ownership);
}
+} // namespace llvm
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsLLVM, m) {
m.doc() = "MLIR LLVM Dialect";
- populateDialectLLVMSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::llvm::populateDialectLLVMSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index 18917416412c1..179cc32520e83 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -8,34 +8,48 @@
#include "mlir-c/Dialect/NVGPU.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
- auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
- m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace nvgpu {
+struct TensorMapDescriptorType : PyConcreteType<TensorMapDescriptorType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsANVGPUTensorMapDescriptorType;
+ static constexpr const char *pyClassName = "TensorMapDescriptorType";
+ using PyConcreteType::PyConcreteType;
- nvgpuTensorMapDescriptorType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType tensorMemrefType, int swizzle,
- int l2promo, int oobFill, int interleave, MlirContext ctx) {
- return cls(mlirNVGPUTensorMapDescriptorTypeGet(
- ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
- },
- "Gets an instance of TensorMapDescriptorType in the same context",
- nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"),
- nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"),
- nb::arg("ctx") = nb::none());
-}
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &tensorMemrefType, int swizzle, int l2promo,
+ int oobFill, int interleave, DefaultingPyMlirContext context) {
+ return TensorMapDescriptorType(
+ context->getRef(), mlirNVGPUTensorMapDescriptorTypeGet(
+ context.get()->get(), tensorMemrefType,
+ swizzle, l2promo, oobFill, interleave));
+ },
+ "Gets an instance of TensorMapDescriptorType in the same context",
+ nb::arg("tensor_type"), nb::arg("swizzle"), nb::arg("l2promo"),
+ nb::arg("oob_fill"), nb::arg("interleave"),
+ nb::arg("context").none() = nb::none());
+ }
+};
+} // namespace nvgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsNVGPU, m) {
m.doc() = "MLIR NVGPU dialect.";
- populateDialectNVGPUSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::nvgpu::TensorMapDescriptorType::
+ bind(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index 1acb41080f711..d2ed3b141d724 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -8,98 +8,160 @@
#include "mlir-c/Dialect/PDL.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectPDLSubmodule(const nanobind::module_ &m) {
- //===-------------------------------------------------------------------===//
- // PDLType
- //===-------------------------------------------------------------------===//
-
- auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType);
-
- //===-------------------------------------------------------------------===//
- // AttributeType
- //===-------------------------------------------------------------------===//
-
- auto attributeType =
- mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
- attributeType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLAttributeTypeGet(ctx));
- },
- "Get an instance of AttributeType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // OperationType
- //===-------------------------------------------------------------------===//
-
- auto operationType =
- mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
- operationType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLOperationTypeGet(ctx));
- },
- "Get an instance of OperationType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // RangeType
- //===-------------------------------------------------------------------===//
-
- auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
- rangeType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType) {
- return cls(mlirPDLRangeTypeGet(elementType));
- },
- "Gets an instance of RangeType in the same context as the provided "
- "element type.",
- nb::arg("cls"), nb::arg("element_type"));
- rangeType.def_property_readonly(
- "element_type",
- [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
- nb::sig(
- "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
- "Get the element type.");
-
- //===-------------------------------------------------------------------===//
- // TypeType
- //===-------------------------------------------------------------------===//
-
- auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
- typeType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLTypeTypeGet(ctx));
- },
- "Get an instance of TypeType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // ValueType
- //===-------------------------------------------------------------------===//
-
- auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
- valueType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLValueTypeGet(ctx));
- },
- "Get an instance of TypeType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace pdl {
+
+//===-------------------------------------------------------------------===//
+// PDLType
+//===-------------------------------------------------------------------===//
+
+struct PDLType : PyConcreteType<PDLType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLType;
+ static constexpr const char *pyClassName = "PDLType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {}
+};
+
+//===-------------------------------------------------------------------===//
+// AttributeType
+//===-------------------------------------------------------------------===//
+
+struct AttributeType : PyConcreteType<AttributeType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType;
+ static constexpr const char *pyClassName = "AttributeType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return AttributeType(context->getRef(),
+ mlirPDLAttributeTypeGet(context.get()->get()));
+ },
+ "Get an instance of AttributeType in given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// OperationType
+//===-------------------------------------------------------------------===//
+
+struct OperationType : PyConcreteType<OperationType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType;
+ static constexpr const char *pyClassName = "OperationType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return OperationType(context->getRef(),
+ mlirPDLOperationTypeGet(context.get()->get()));
+ },
+ "Get an instance of OperationType in given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// RangeType
+//===-------------------------------------------------------------------===//
+
+struct RangeType : PyConcreteType<RangeType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType;
+ static constexpr const char *pyClassName = "RangeType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &elementType, DefaultingPyMlirContext context) {
+ return RangeType(context->getRef(), mlirPDLRangeTypeGet(elementType));
+ },
+ "Gets an instance of RangeType in the same context as the provided "
+ "element type.",
+ nb::arg("element_type"), nb::arg("context").none() = nb::none());
+ c.def_prop_ro(
+ "element_type",
+ [](PyType &type) {
+ return PyType(type.getContext(),
+ mlirPDLRangeTypeGetElementType(type));
+ },
+ nb::sig(
+ "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
+ "Get the element type.");
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// TypeType
+//===-------------------------------------------------------------------===//
+
+struct TypeType : PyConcreteType<TypeType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType;
+ static constexpr const char *pyClassName = "TypeType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return TypeType(context->getRef(),
+ mlirPDLTypeTypeGet(context.get()->get()));
+ },
+ "Get an instance of TypeType in given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// ValueType
+//===-------------------------------------------------------------------===//
+
+struct ValueType : PyConcreteType<ValueType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType;
+ static constexpr const char *pyClassName = "ValueType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return ValueType(context->getRef(),
+ mlirPDLValueTypeGet(context.get()->get()));
+ },
+ "Get an instance of ValueType in given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+static void populateDialectPDLSubmodule(nanobind::module_ &m) {
+ PDLType::bind(m);
+ AttributeType::bind(m);
+ OperationType::bind(m);
+ RangeType::bind(m);
+ TypeType::bind(m);
+ ValueType::bind(m);
}
+} // namespace pdl
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsPDL, m) {
m.doc() = "MLIR PDL dialect.";
- populateDialectPDLSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::pdl::populateDialectPDLSubmodule(
+ m);
}
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index a5220fcc00604..a1e0a281a708d 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -6,385 +6,485 @@
//
//===----------------------------------------------------------------------===//
-#include <cstdint>
#include <vector>
-#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectQuantSubmodule(const nb::module_ &m) {
- //===-------------------------------------------------------------------===//
- // QuantizedType
- //===-------------------------------------------------------------------===//
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace quant {
+//===-------------------------------------------------------------------===//
+// QuantizedType
+//===-------------------------------------------------------------------===//
- auto quantizedType =
- mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
- quantizedType.def_staticmethod(
- "default_minimum_for_integer",
- [](bool isSigned, unsigned integralWidth) {
- return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
- integralWidth);
- },
- "Default minimum value for the integer with the specified signedness and "
- "bit width.",
- nb::arg("is_signed"), nb::arg("integral_width"));
- quantizedType.def_staticmethod(
- "default_maximum_for_integer",
- [](bool isSigned, unsigned integralWidth) {
- return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
- integralWidth);
- },
- "Default maximum value for the integer with the specified signedness and "
- "bit width.",
- nb::arg("is_signed"), nb::arg("integral_width"));
- quantizedType.def_property_readonly(
- "expressed_type",
- [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
- "Type expressed by this quantized type.");
- quantizedType.def_property_readonly(
- "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
- "Flags of this quantized type (named accessors should be preferred to "
- "this)");
- quantizedType.def_property_readonly(
- "is_signed",
- [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
- "Signedness of this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type",
- [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
- "Storage type backing this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type_min",
- [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
- "The minimum value held by the storage type of this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type_max",
- [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
- "The maximum value held by the storage type of this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type_integral_width",
- [](MlirType type) {
- return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
- },
- "The bitwidth of the storage type of this quantized type.");
- quantizedType.def(
- "is_compatible_expressed_type",
- [](MlirType type, MlirType candidate) {
- return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
- },
- "Checks whether the candidate type can be expressed by this quantized "
- "type.",
- nb::arg("candidate"));
- quantizedType.def_property_readonly(
- "quantized_element_type",
- [](MlirType type) {
- return mlirQuantizedTypeGetQuantizedElementType(type);
- },
- "Element type of this quantized type expressed as quantized type.");
- quantizedType.def(
- "cast_from_storage_type",
- [](MlirType type, MlirType candidate) {
- MlirType castResult =
- mlirQuantizedTypeCastFromStorageType(type, candidate);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on the storage type of this quantized type to a "
- "corresponding type based on the quantized type. Raises TypeError if the "
- "cast is not valid.",
- nb::arg("candidate"));
- quantizedType.def_staticmethod(
- "cast_to_storage_type",
- [](MlirType type) {
- MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on a quantized type to a corresponding type "
- "based on the storage type of this quantized type. Raises TypeError if "
- "the cast is not valid.",
- nb::arg("type"));
- quantizedType.def(
- "cast_from_expressed_type",
- [](MlirType type, MlirType candidate) {
- MlirType castResult =
- mlirQuantizedTypeCastFromExpressedType(type, candidate);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on the expressed type of this quantized type to "
- "a corresponding type based on the quantized type. Raises TypeError if "
- "the cast is not valid.",
- nb::arg("candidate"));
- quantizedType.def_staticmethod(
- "cast_to_expressed_type",
- [](MlirType type) {
- MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on a quantized type to a corresponding type "
- "based on the expressed type of this quantized type. Raises TypeError if "
- "the cast is not valid.",
- nb::arg("type"));
- quantizedType.def(
- "cast_expressed_to_storage_type",
- [](MlirType type, MlirType candidate) {
- MlirType castResult =
- mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on the expressed type of this quantized type to "
- "a corresponding type based on the storage type. Raises TypeError if the "
- "cast is not valid.",
- nb::arg("candidate"));
+struct QuantizedType : PyConcreteType<QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAQuantizedType;
+ static constexpr const char *pyClassName = "QuantizedType";
+ using PyConcreteType::PyConcreteType;
- quantizedType.get_class().attr("FLAG_SIGNED") =
- mlirQuantizedTypeGetSignedFlag();
+ static void bindDerived(ClassTy &c) { // Registers QuantizedType's static helpers, read-only properties and cast methods on `c`.
+ c.def_static( // Static helper: forwards to the C API, no type instance involved.
+ "default_minimum_for_integer",
+ [](bool isSigned, unsigned integralWidth) {
+ return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
+ integralWidth);
+ },
+ "Default minimum value for the integer with the specified signedness "
+ "and "
+ "bit width.",
+ nb::arg("is_signed"), nb::arg("integral_width"));
+ c.def_static( // Static helper: counterpart of the minimum above.
+ "default_maximum_for_integer",
+ [](bool isSigned, unsigned integralWidth) {
+ return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
+ integralWidth);
+ },
+ "Default maximum value for the integer with the specified signedness "
+ "and "
+ "bit width.",
+ nb::arg("is_signed"), nb::arg("integral_width"));
+ c.def_prop_ro(
+ "expressed_type",
+ [](PyType type) { // Taken by value; getContext() is called on it below.
+ return PyType(type.getContext(), // Result is wrapped as a generic Type sharing `type`'s context.
+ mlirQuantizedTypeGetExpressedType(type));
+ },
+ "Type expressed by this quantized type.");
+ c.def_prop_ro(
+ "flags",
+ [](const PyType &type) { return mlirQuantizedTypeGetFlags(type); },
+ "Flags of this quantized type (named accessors should be preferred to "
+ "this)");
+ c.def_prop_ro(
+ "is_signed",
+ [](const PyType &type) { return mlirQuantizedTypeIsSigned(type); },
+ "Signedness of this quantized type.");
+ c.def_prop_ro(
+ "storage_type",
+ [](PyType type) { // Taken by value; getContext() is called on it below.
+ return PyType(type.getContext(),
+ mlirQuantizedTypeGetStorageType(type));
+ },
+ "Storage type backing this quantized type.");
+ c.def_prop_ro(
+ "storage_type_min",
+ [](const PyType &type) {
+ return mlirQuantizedTypeGetStorageTypeMin(type);
+ },
+ "The minimum value held by the storage type of this quantized type.");
+ c.def_prop_ro(
+ "storage_type_max",
+ [](const PyType &type) {
+ return mlirQuantizedTypeGetStorageTypeMax(type);
+ },
+ "The maximum value held by the storage type of this quantized type.");
+ c.def_prop_ro(
+ "storage_type_integral_width",
+ [](const PyType &type) {
+ return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
+ },
+ "The bitwidth of the storage type of this quantized type.");
+ c.def(
+ "is_compatible_expressed_type",
+ [](const PyType &type, const PyType &candidate) {
+ return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
+ },
+ "Checks whether the candidate type can be expressed by this quantized "
+ "type.",
+ nb::arg("candidate"));
+ c.def_prop_ro(
+ "quantized_element_type",
+ [](PyType type) { // Taken by value; getContext() is called on it below.
+ return PyType(type.getContext(),
+ mlirQuantizedTypeGetQuantizedElementType(type));
+ },
+ "Element type of this quantized type expressed as quantized type.");
+ c.def( // Instance method: a null C API result is reported as TypeError.
+ "cast_from_storage_type",
+ [](PyType type, const PyType &candidate) {
+ MlirType castResult =
+ mlirQuantizedTypeCastFromStorageType(type, candidate);
+ if (!mlirTypeIsNull(castResult))
+ return PyType(type.getContext(), castResult); // Generic Type, like the sibling casts: the result is only "based on" the quantized type, so it is not forced into a QuantizedType wrapper.
+ throw nb::type_error("Invalid cast.");
+ },
+ "Casts from a type based on the storage type of this quantized type to "
+ "a "
+ "corresponding type based on the quantized type. Raises TypeError if "
+ "the "
+ "cast is not valid.",
+ nb::arg("candidate"));
+ c.def_static( // Static: the type to cast is passed explicitly as `type`.
+ "cast_to_storage_type",
+ [](const PyType &type) {
+ MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
+ if (!mlirTypeIsNull(castResult))
+ return castResult; // Returned as a raw MlirType rather than a PyType wrapper.
+ throw nb::type_error("Invalid cast.");
+ },
+ "Casts from a type based on a quantized type to a corresponding type "
+ "based on the storage type of this quantized type. Raises TypeError if "
+ "the cast is not valid.",
+ nb::arg("type"));
+ c.def( // Instance method: a null C API result is reported as TypeError.
+ "cast_from_expressed_type",
+ [](PyType type, const PyType &candidate) {
+ MlirType castResult =
+ mlirQuantizedTypeCastFromExpressedType(type, candidate);
+ if (!mlirTypeIsNull(castResult))
+ return PyType(type.getContext(), castResult);
+ throw nb::type_error("Invalid cast.");
+ },
+ "Casts from a type based on the expressed type of this quantized type "
+ "to "
+ "a corresponding type based on the quantized type. Raises TypeError if "
+ "the cast is not valid.",
+ nb::arg("candidate"));
+ c.def_static( // Static: the type to cast is passed explicitly as `type`.
+ "cast_to_expressed_type",
+ [](const PyType &type) {
+ MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
+ if (!mlirTypeIsNull(castResult))
+ return castResult; // Returned as a raw MlirType rather than a PyType wrapper.
+ throw nb::type_error("Invalid cast.");
+ },
+ "Casts from a type based on a quantized type to a corresponding type "
+ "based on the expressed type of this quantized type. Raises TypeError "
+ "if "
+ "the cast is not valid.",
+ nb::arg("type"));
+ c.def( // Instance method: a null C API result is reported as TypeError.
+ "cast_expressed_to_storage_type",
+ [](PyType type, const PyType &candidate) {
+ MlirType castResult =
+ mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
+ if (!mlirTypeIsNull(castResult))
+ return PyType(type.getContext(), castResult);
+ throw nb::type_error("Invalid cast.");
+ },
+ "Casts from a type based on the expressed type of this quantized type "
+ "to "
+ "a corresponding type based on the storage type. Raises TypeError if "
+ "the "
+ "cast is not valid.",
+ nb::arg("candidate"));
+ }
+};
- //===-------------------------------------------------------------------===//
- // AnyQuantizedType
- //===-------------------------------------------------------------------===//
+//===-------------------------------------------------------------------===//
+// AnyQuantizedType
+//===-------------------------------------------------------------------===//
- auto anyQuantizedType =
- mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
- quantizedType.get_class());
- anyQuantizedType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, int64_t storageTypeMin,
- int64_t storageTypeMax) {
- return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
- storageTypeMin, storageTypeMax));
- },
- "Gets an instance of AnyQuantizedType in the same context as the "
- "provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("storage_type_min"),
- nb::arg("storage_type_max"));
+struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType;
+ static constexpr const char *pyClassName = "AnyQuantizedType";
+ using PyConcreteType::PyConcreteType;
- //===-------------------------------------------------------------------===//
- // UniformQuantizedType
- //===-------------------------------------------------------------------===//
+ static void bindDerived(ClassTy &c) { // Registers the static "get" factory for AnyQuantizedType on `c`.
+ c.def_static(
+ "get",
+ [](unsigned flags, const PyType &storageType,
+ const PyType &expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax, DefaultingPyMlirContext context) {
+ return AnyQuantizedType(
+ context->getRef(), // NOTE(review): the wrapper takes the `context` argument's ref, while the docstring speaks of the storage type's context — confirm the two always agree.
+ mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
+ storageTypeMin, storageTypeMax));
+ },
+ "Gets an instance of AnyQuantizedType in the same context as the "
+ "provided storage type.",
+ nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+ nb::arg("storage_type_min"), nb::arg("storage_type_max"),
+ nb::arg("context") = nb::none()); // `context` is the last parameter and defaults to None.
+ }
+};
- auto uniformQuantizedType = mlir_type_subclass(
- m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
- quantizedType.get_class());
- uniformQuantizedType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, double scale, int64_t zeroPoint,
- int64_t storageTypeMin, int64_t storageTypeMax) {
- return cls(mlirUniformQuantizedTypeGet(flags, storageType,
- expressedType, scale, zeroPoint,
- storageTypeMin, storageTypeMax));
- },
- "Gets an instance of UniformQuantizedType in the same context as the "
- "provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"),
- nb::arg("storage_type_min"), nb::arg("storage_type_max"));
- uniformQuantizedType.def_property_readonly(
- "scale",
- [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
- "The scale designates the difference between the real values "
- "corresponding to consecutive quantized values differing by 1.");
- uniformQuantizedType.def_property_readonly(
- "zero_point",
- [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
- "The storage value corresponding to the real value 0 in the affine "
- "equation.");
- uniformQuantizedType.def_property_readonly(
- "is_fixed_point",
- [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
- "Fixed point values are real numbers divided by a scale.");
+//===-------------------------------------------------------------------===//
+// UniformQuantizedType
+//===-------------------------------------------------------------------===//
- //===-------------------------------------------------------------------===//
- // UniformQuantizedPerAxisType
- //===-------------------------------------------------------------------===//
- auto uniformQuantizedPerAxisType = mlir_type_subclass(
- m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
- quantizedType.get_class());
- uniformQuantizedPerAxisType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, std::vector<double> scales,
- std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
- int64_t storageTypeMin, int64_t storageTypeMax) {
- if (scales.size() != zeroPoints.size())
- throw nb::value_error(
- "Mismatching number of scales and zero points.");
- auto nDims = static_cast<intptr_t>(scales.size());
- return cls(mlirUniformQuantizedPerAxisTypeGet(
- flags, storageType, expressedType, nDims, scales.data(),
- zeroPoints.data(), quantizedDimension, storageTypeMin,
- storageTypeMax));
- },
- "Gets an instance of UniformQuantizedPerAxisType in the same context as "
- "the provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
- nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
- nb::arg("storage_type_max"));
- uniformQuantizedPerAxisType.def_property_readonly(
- "scales",
- [](MlirType type) {
- intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
- std::vector<double> scales;
- scales.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
- scales.push_back(scale);
- }
- return scales;
- },
- "The scales designate the difference between the real values "
- "corresponding to consecutive quantized values differing by 1. The ith "
- "scale corresponds to the ith slice in the quantized_dimension.");
- uniformQuantizedPerAxisType.def_property_readonly(
- "zero_points",
- [](MlirType type) {
- intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
- std::vector<int64_t> zeroPoints;
- zeroPoints.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- int64_t zeroPoint =
- mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
- zeroPoints.push_back(zeroPoint);
- }
- return zeroPoints;
- },
- "the storage values corresponding to the real value 0 in the affine "
- "equation. The ith zero point corresponds to the ith slice in the "
- "quantized_dimension.");
- uniformQuantizedPerAxisType.def_property_readonly(
- "quantized_dimension",
- [](MlirType type) {
- return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
- },
- "Specifies the dimension of the shape that the scales and zero points "
- "correspond to.");
- uniformQuantizedPerAxisType.def_property_readonly(
- "is_fixed_point",
- [](MlirType type) {
- return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
- },
- "Fixed point values are real numbers divided by a scale.");
+struct UniformQuantizedType
+ : PyConcreteType<UniformQuantizedType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType;
+ static constexpr const char *pyClassName = "UniformQuantizedType";
+ using PyConcreteType::PyConcreteType;
- //===-------------------------------------------------------------------===//
- // UniformQuantizedSubChannelType
- //===-------------------------------------------------------------------===//
- auto uniformQuantizedSubChannelType = mlir_type_subclass(
- m, "UniformQuantizedSubChannelType",
- mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
- uniformQuantizedSubChannelType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
- std::vector<int32_t> quantizedDimensions,
- std::vector<int64_t> blockSizes, int64_t storageTypeMin,
- int64_t storageTypeMax) {
- return cls(mlirUniformQuantizedSubChannelTypeGet(
- flags, storageType, expressedType, scales, zeroPoints,
- static_cast<intptr_t>(blockSizes.size()),
- quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
- storageTypeMax));
- },
- "Gets an instance of UniformQuantizedSubChannel in the same context as "
- "the provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
- nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
- nb::arg("storage_type_min"), nb::arg("storage_type_max"));
- uniformQuantizedSubChannelType.def_property_readonly(
- "quantized_dimensions",
- [](MlirType type) {
- intptr_t nDim =
- mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
- std::vector<int32_t> quantizedDimensions;
- quantizedDimensions.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- quantizedDimensions.push_back(
- mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
- }
- return quantizedDimensions;
- },
- "Gets the quantized dimensions. Each element in the returned list "
- "represents an axis of the quantized data tensor that has a specified "
- "block size. The order of elements corresponds to the order of block "
- "sizes returned by 'block_sizes' method. It means that the data tensor "
- "is quantized along the i-th dimension in the returned list using the "
- "i-th block size from block_sizes method.");
- uniformQuantizedSubChannelType.def_property_readonly(
- "block_sizes",
- [](MlirType type) {
- intptr_t nDim =
- mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
- std::vector<int64_t> blockSizes;
- blockSizes.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- blockSizes.push_back(
- mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
- }
- return blockSizes;
- },
- "Gets the block sizes for the quantized dimensions. The i-th element in "
- "the returned list corresponds to the block size for the i-th dimension "
- "in the list returned by quantized_dimensions method.");
- uniformQuantizedSubChannelType.def_property_readonly(
- "scales",
- [](MlirType type) -> MlirAttribute {
- return mlirUniformQuantizedSubChannelTypeGetScales(type);
- },
- "The scales of the quantized type.");
- uniformQuantizedSubChannelType.def_property_readonly(
- "zero_points",
- [](MlirType type) -> MlirAttribute {
- return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
- },
- "The zero points of the quantized type.");
+ static void bindDerived(ClassTy &c) { // Registers the "get" factory and the scale / zero_point / is_fixed_point properties on `c`.
+ c.def_static(
+ "get",
+ [](unsigned flags, const PyType &storageType,
+ const PyType &expressedType, double scale, int64_t zeroPoint,
+ int64_t storageTypeMin, int64_t storageTypeMax,
+ DefaultingPyMlirContext context) {
+ return UniformQuantizedType(
+ context->getRef(), // NOTE(review): the wrapper takes the `context` argument's ref, while the docstring speaks of the storage type's context — confirm the two always agree.
+ mlirUniformQuantizedTypeGet(flags, storageType, expressedType,
+ scale, zeroPoint, storageTypeMin,
+ storageTypeMax));
+ },
+ "Gets an instance of UniformQuantizedType in the same context as the "
+ "provided storage type.",
+ nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+ nb::arg("scale"), nb::arg("zero_point"), nb::arg("storage_type_min"),
+ nb::arg("storage_type_max"), nb::arg("context") = nb::none()); // `context` is the last parameter and defaults to None.
+ c.def_prop_ro( // Read-only property forwarding to the C API getter.
+ "scale",
+ [](const PyType &type) {
+ return mlirUniformQuantizedTypeGetScale(type);
+ },
+ "The scale designates the difference between the real values "
+ "corresponding to consecutive quantized values differing by 1.");
+ c.def_prop_ro( // Read-only property forwarding to the C API getter.
+ "zero_point",
+ [](const PyType &type) {
+ return mlirUniformQuantizedTypeGetZeroPoint(type);
+ },
+ "The storage value corresponding to the real value 0 in the affine "
+ "equation.");
+ c.def_prop_ro( // Read-only property forwarding to the C API predicate.
+ "is_fixed_point",
+ [](const PyType &type) {
+ return mlirUniformQuantizedTypeIsFixedPoint(type);
+ },
+ "Fixed point values are real numbers divided by a scale.");
+ }
+};
- //===-------------------------------------------------------------------===//
- // CalibratedQuantizedType
- //===-------------------------------------------------------------------===//
+//===-------------------------------------------------------------------===//
+// UniformQuantizedPerAxisType
+//===-------------------------------------------------------------------===//
- auto calibratedQuantizedType = mlir_type_subclass(
- m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
- quantizedType.get_class());
- calibratedQuantizedType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType expressedType, double min,
- double max) {
- return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
- },
- "Gets an instance of CalibratedQuantizedType in the same context as the "
- "provided expressed type.",
- nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"),
- nb::arg("max"));
- calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
- return mlirCalibratedQuantizedTypeGetMin(type);
- });
- calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
- return mlirCalibratedQuantizedTypeGetMax(type);
- });
+struct UniformQuantizedPerAxisType // Binding struct for per-axis uniform quantized types; derives from the QuantizedType binding.
+ : PyConcreteType<UniformQuantizedPerAxisType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAUniformQuantizedPerAxisType; // C API predicate associated with this binding.
+ static constexpr const char *pyClassName = "UniformQuantizedPerAxisType"; // Class name string consumed by the PyConcreteType base.
+ using PyConcreteType::PyConcreteType; // Inherit the base-class constructors.
+
+ static void bindDerived(ClassTy &c) { // Registers the "get" factory and the per-axis read-only properties on `c`.
+ c.def_static(
+ "get",
+ [](unsigned flags, const PyType &storageType,
+ const PyType &expressedType, std::vector<double> scales,
+ std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
+ int64_t storageTypeMin, int64_t storageTypeMax,
+ DefaultingPyMlirContext context) {
+ if (scales.size() != zeroPoints.size()) // One element count (nDims) is passed for both arrays, so their lengths must match.
+ throw nb::value_error(
+ "Mismatching number of scales and zero points.");
+ auto nDims = static_cast<intptr_t>(scales.size());
+ return UniformQuantizedPerAxisType(
+ context->getRef(), // NOTE(review): the wrapper takes the `context` argument's ref, while the docstring speaks of the storage type's context — confirm the two always agree.
+ mlirUniformQuantizedPerAxisTypeGet(
+ flags, storageType, expressedType, nDims, scales.data(),
+ zeroPoints.data(), quantizedDimension, storageTypeMin,
+ storageTypeMax));
+ },
+ "Gets an instance of UniformQuantizedPerAxisType in the same context "
+ "as "
+ "the provided storage type.",
+ nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+ nb::arg("scales"), nb::arg("zero_points"),
+ nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
+ nb::arg("storage_type_max"), nb::arg("context") = nb::none()); // `context` is the last parameter and defaults to None.
+ c.def_prop_ro( // Collects one scale per dimension index into a vector.
+ "scales",
+ [](const PyType &type) {
+ intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
+ std::vector<double> scales;
+ scales.reserve(nDim); // Final size is known up front.
+ for (intptr_t i = 0; i < nDim; ++i) {
+ double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
+ scales.push_back(scale);
+ }
+ return scales;
+ },
+ "The scales designate the difference between the real values "
+ "corresponding to consecutive quantized values differing by 1. The ith "
+ "scale corresponds to the ith slice in the quantized_dimension.");
+ c.def_prop_ro( // Collects one zero point per dimension index into a vector.
+ "zero_points",
+ [](const PyType &type) {
+ intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
+ std::vector<int64_t> zeroPoints;
+ zeroPoints.reserve(nDim); // Final size is known up front.
+ for (intptr_t i = 0; i < nDim; ++i) {
+ int64_t zeroPoint =
+ mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
+ zeroPoints.push_back(zeroPoint);
+ }
+ return zeroPoints;
+ },
+ "the storage values corresponding to the real value 0 in the affine "
+ "equation. The ith zero point corresponds to the ith slice in the "
+ "quantized_dimension.");
+ c.def_prop_ro( // Read-only property forwarding to the C API getter.
+ "quantized_dimension",
+ [](const PyType &type) {
+ return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
+ },
+ "Specifies the dimension of the shape that the scales and zero points "
+ "correspond to.");
+ c.def_prop_ro( // Read-only property forwarding to the C API predicate.
+ "is_fixed_point",
+ [](const PyType &type) {
+ return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
+ },
+ "Fixed point values are real numbers divided by a scale.");
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===-------------------------------------------------------------------===//
+
+struct UniformQuantizedSubChannelType // Binding struct for sub-channel uniform quantized types; derives from the QuantizedType binding.
+ : PyConcreteType<UniformQuantizedSubChannelType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAUniformQuantizedSubChannelType; // C API predicate associated with this binding.
+ static constexpr const char *pyClassName = "UniformQuantizedSubChannelType"; // Class name string consumed by the PyConcreteType base.
+ using PyConcreteType::PyConcreteType; // Inherit the base-class constructors.
+
+ static void bindDerived(ClassTy &c) { // Registers the "get" factory and the sub-channel read-only properties on `c`.
+ c.def_static(
+ "get",
+ [](unsigned flags, const PyType &storageType,
+ const PyType &expressedType, MlirAttribute scales,
+ MlirAttribute zeroPoints, std::vector<int32_t> quantizedDimensions,
+ std::vector<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax, DefaultingPyMlirContext context) {
+ if (quantizedDimensions.size() != blockSizes.size()) // One element count is passed for both arrays below; reject mismatched lengths instead of handing the C API a count larger than one buffer.
+ throw nb::value_error(
+ "Mismatching number of quantized dimensions and block sizes.");
+ return UniformQuantizedSubChannelType(
+ context->getRef(), // NOTE(review): the wrapper takes the `context` argument's ref, while the docstring speaks of the storage type's context — confirm the two always agree.
+ mlirUniformQuantizedSubChannelTypeGet(
+ flags, storageType, expressedType, scales, zeroPoints,
+ static_cast<intptr_t>(blockSizes.size()),
+ quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
+ storageTypeMax));
+ },
+ "Gets an instance of UniformQuantizedSubChannel in the same context as "
+ "the provided storage type.",
+ nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+ nb::arg("scales"), nb::arg("zero_points"),
+ nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
+ nb::arg("storage_type_min"), nb::arg("storage_type_max"),
+ nb::arg("context") = nb::none()); // `context` is the last parameter and defaults to None.
+ c.def_prop_ro( // Collects one quantized dimension per block-size index into a vector.
+ "quantized_dimensions",
+ [](const PyType &type) {
+ intptr_t nDim =
+ mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+ std::vector<int32_t> quantizedDimensions;
+ quantizedDimensions.reserve(nDim); // Final size is known up front.
+ for (intptr_t i = 0; i < nDim; ++i) {
+ quantizedDimensions.push_back(
+ mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type,
+ i));
+ }
+ return quantizedDimensions;
+ },
+ "Gets the quantized dimensions. Each element in the returned list "
+ "represents an axis of the quantized data tensor that has a specified "
+ "block size. The order of elements corresponds to the order of block "
+ "sizes returned by 'block_sizes' method. It means that the data tensor "
+ "is quantized along the i-th dimension in the returned list using the "
+ "i-th block size from block_sizes method.");
+ c.def_prop_ro( // Collects every block size into a vector.
+ "block_sizes",
+ [](const PyType &type) {
+ intptr_t nDim =
+ mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+ std::vector<int64_t> blockSizes;
+ blockSizes.reserve(nDim); // Final size is known up front.
+ for (intptr_t i = 0; i < nDim; ++i) {
+ blockSizes.push_back(
+ mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
+ }
+ return blockSizes;
+ },
+ "Gets the block sizes for the quantized dimensions. The i-th element "
+ "in "
+ "the returned list corresponds to the block size for the i-th "
+ "dimension "
+ "in the list returned by quantized_dimensions method.");
+ c.def_prop_ro( // Returned as a raw MlirAttribute (explicit lambda return type).
+ "scales",
+ [](const PyType &type) -> MlirAttribute {
+ return mlirUniformQuantizedSubChannelTypeGetScales(type);
+ },
+ "The scales of the quantized type.");
+ c.def_prop_ro( // Returned as a raw MlirAttribute (explicit lambda return type).
+ "zero_points",
+ [](const PyType &type) -> MlirAttribute {
+ return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+ },
+ "The zero points of the quantized type.");
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// CalibratedQuantizedType
+//===-------------------------------------------------------------------===//
+
+struct CalibratedQuantizedType // Binding struct for calibrated quantized types; derives from the QuantizedType binding.
+ : PyConcreteType<CalibratedQuantizedType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsACalibratedQuantizedType; // C API predicate associated with this binding.
+ static constexpr const char *pyClassName = "CalibratedQuantizedType"; // Class name string consumed by the PyConcreteType base.
+ using PyConcreteType::PyConcreteType; // Inherit the base-class constructors.
+
+ static void bindDerived(ClassTy &c) { // Registers the "get" factory and the min / max read-only properties on `c`.
+ c.def_static(
+ "get",
+ [](const PyType &expressedType, double min, double max,
+ DefaultingPyMlirContext context) {
+ return CalibratedQuantizedType(
+ context->getRef(), // NOTE(review): the wrapper takes the `context` argument's ref, while the docstring speaks of the expressed type's context — confirm the two always agree.
+ mlirCalibratedQuantizedTypeGet(expressedType, min, max));
+ },
+ "Gets an instance of CalibratedQuantizedType in the same context as "
+ "the "
+ "provided expressed type.",
+ nb::arg("expressed_type"), nb::arg("min"), nb::arg("max"),
+ nb::arg("context") = nb::none()); // `context` is the last parameter and defaults to None.
+ c.def_prop_ro("min", [](const PyType &type) { // Registered without a docstring, unlike the properties of the sibling types.
+ return mlirCalibratedQuantizedTypeGetMin(type);
+ });
+ c.def_prop_ro("max", [](const PyType &type) { // Registered without a docstring, unlike the properties of the sibling types.
+ return mlirCalibratedQuantizedTypeGetMax(type);
+ });
+ }
+};
+
+static void populateDialectQuantSubmodule(nb::module_ &m) { // Binds every quant dialect type class onto module `m`.
+ QuantizedType::bind(m); // Base class of the five types bound below, so it is bound first.
+
+ // FLAG_SIGNED is attached to the already-bound class object, which is looked
+ // up by name on the module; this therefore has to follow QuantizedType::bind.
+ auto quantizedTypeClass = m.attr("QuantizedType");
+ quantizedTypeClass.attr("FLAG_SIGNED") = mlirQuantizedTypeGetSignedFlag();
+
+ AnyQuantizedType::bind(m);
+ UniformQuantizedType::bind(m);
+ UniformQuantizedPerAxisType::bind(m);
+ UniformQuantizedSubChannelType::bind(m);
+ CalibratedQuantizedType::bind(m);
}
+} // namespace quant
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsQuant, m) { // Entry point of the `_mlirDialectsQuant` extension module.
m.doc() = "MLIR Quantization dialect";
- populateDialectQuantSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::quant:: // Fully qualified: the helper now lives in nested namespaces.
+ populateDialectQuantSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index a87918a05b126..39490155d5216 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -13,44 +13,77 @@
#include "mlir-c/Support.h"
#include "mlir-c/Target/ExportSMTLIB.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace smt {
+struct BoolType : PyConcreteType<BoolType> {
+ static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABool;
+ static constexpr const char *pyClassName = "BoolType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return BoolType(context->getRef(),
+ mlirSMTTypeGetBool(context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+struct BitVectorType : PyConcreteType<BitVectorType> {
+ static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABitVector;
+ static constexpr const char *pyClassName = "BitVectorType";
+ using PyConcreteType::PyConcreteType;
- auto smtBoolType =
- mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
- .def_staticmethod(
- "get",
- [](MlirContext context) { return mlirSMTTypeGetBool(context); },
- "context"_a = nb::none());
- auto smtBitVectorType =
- mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
- .def_staticmethod(
- "get",
- [](int32_t width, MlirContext context) {
- return mlirSMTTypeGetBitVector(context, width);
- },
- "width"_a, "context"_a = nb::none());
- auto smtIntType =
- mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
- .def_staticmethod(
- "get",
- [](MlirContext context) { return mlirSMTTypeGetInt(context); },
- "context"_a = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](int32_t width, DefaultingPyMlirContext context) {
+ return BitVectorType(
+ context->getRef(),
+ mlirSMTTypeGetBitVector(context.get()->get(), width));
+ },
+ nb::arg("width"), nb::arg("context").none() = nb::none());
+ }
+};
+
+struct IntType : PyConcreteType<IntType> {
+ static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsAInt;
+ static constexpr const char *pyClassName = "IntType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return IntType(context->getRef(),
+ mlirSMTTypeGetInt(context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+ BoolType::bind(m);
+ BitVectorType::bind(m);
+ IntType::bind(m);
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
- mlir::python::CollectDiagnosticsToStringScope scope(
- mlirOperationGetContext(module));
+ CollectDiagnosticsToStringScope scope(mlirOperationGetContext(module));
PyPrintAccumulator printAccum;
MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
module, printAccum.getCallback(), printAccum.getUserData(),
@@ -80,9 +113,13 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
"module"_a, "inline_single_use_values"_a = false,
"indent_let_body"_a = false);
}
+} // namespace smt
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsSMT, m) {
m.doc() = "MLIR SMT Dialect";
- populateDialectSMTSubmodule(m);
+ python::MLIR_BINDINGS_PYTHON_DOMAIN::smt::populateDialectSMTSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 00b65ee9745dc..6ec58dd88d24f 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -12,137 +12,179 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
- nb::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", nb::is_arithmetic(),
- nb::is_flag())
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace sparse_tensor {
+
+enum PySparseTensorLevelFormat : std::underlying_type_t<
+ MlirSparseTensorLevelFormat> {
+ MLIR_SPARSE_TENSOR_LEVEL_DENSE = MLIR_SPARSE_TENSOR_LEVEL_DENSE,
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = MLIR_SPARSE_TENSOR_LEVEL_SINGLETON,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED =
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED
+};
+
+enum PySparseTensorLevelPropertyNondefault : std::underlying_type_t<
+ MlirSparseTensorLevelPropertyNondefault> {
+ MLIR_SPARSE_PROPERTY_NON_ORDERED = MLIR_SPARSE_PROPERTY_NON_ORDERED,
+ MLIR_SPARSE_PROPERTY_NON_UNIQUE = MLIR_SPARSE_PROPERTY_NON_UNIQUE,
+ MLIR_SPARSE_PROPERTY_SOA = MLIR_SPARSE_PROPERTY_SOA,
+};
+
+struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirAttributeIsASparseTensorEncodingAttr;
+ static constexpr const char *pyClassName = "EncodingAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](std::vector<MlirSparseTensorLevelType> lvlTypes,
+ std::optional<MlirAffineMap> dimToLvl,
+ std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
+ std::optional<MlirAttribute> explicitVal,
+ std::optional<MlirAttribute> implicitVal,
+ DefaultingPyMlirContext context) {
+ return EncodingAttr(
+ context->getRef(),
+ mlirSparseTensorEncodingAttrGet(
+ context.get()->get(), lvlTypes.size(), lvlTypes.data(),
+ dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
+ lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
+ crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
+ implicitVal ? *implicitVal : MlirAttribute{nullptr}));
+ },
+ nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
+ nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
+ nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(),
+ nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(),
+ "Gets a sparse_tensor.encoding from parameters.");
+
+ c.def_static(
+ "build_level_type",
+ [](PySparseTensorLevelFormat lvlFmt,
+ const std::vector<PySparseTensorLevelPropertyNondefault> &properties,
+ unsigned n, unsigned m) {
+ std::vector<MlirSparseTensorLevelPropertyNondefault> props;
+ props.reserve(properties.size());
+ for (auto prop : properties) {
+ props.push_back(
+ static_cast<MlirSparseTensorLevelPropertyNondefault>(prop));
+ }
+ return mlirSparseTensorEncodingAttrBuildLvlType(
+ static_cast<MlirSparseTensorLevelFormat>(lvlFmt), props.data(),
+ props.size(), n, m);
+ },
+ nb::arg("lvl_fmt"),
+ nb::arg("properties") =
+ std::vector<PySparseTensorLevelPropertyNondefault>(),
+ nb::arg("n") = 0, nb::arg("m") = 0,
+ "Builds a sparse_tensor.encoding.level_type from parameters.");
+
+ c.def_prop_ro("lvl_types", [](MlirAttribute self) {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ std::vector<MlirSparseTensorLevelType> ret;
+ ret.reserve(lvlRank);
+ for (int l = 0; l < lvlRank; ++l)
+ ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
+ return ret;
+ });
+
+ c.def_prop_ro(
+ "dim_to_lvl", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+ MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
+ if (mlirAffineMapIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ c.def_prop_ro(
+ "lvl_to_dim", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+ MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
+ if (mlirAffineMapIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ c.def_prop_ro("pos_width", mlirSparseTensorEncodingAttrGetPosWidth);
+ c.def_prop_ro("crd_width", mlirSparseTensorEncodingAttrGetCrdWidth);
+
+ c.def_prop_ro(
+ "explicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+ MlirAttribute ret = mlirSparseTensorEncodingAttrGetExplicitVal(self);
+ if (mlirAttributeIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ c.def_prop_ro(
+ "implicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+ MlirAttribute ret = mlirSparseTensorEncodingAttrGetImplicitVal(self);
+ if (mlirAttributeIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ c.def_prop_ro("structured_n", [](MlirAttribute self) -> unsigned {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ return mlirSparseTensorEncodingAttrGetStructuredN(
+ mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+ });
+
+ c.def_prop_ro("structured_m", [](MlirAttribute self) -> unsigned {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ return mlirSparseTensorEncodingAttrGetStructuredM(
+ mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+ });
+
+ c.def_prop_ro("lvl_formats_enum", [](MlirAttribute self) {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ std::vector<PySparseTensorLevelFormat> ret;
+ ret.reserve(lvlRank);
+
+ for (int l = 0; l < lvlRank; l++)
+ ret.push_back(static_cast<PySparseTensorLevelFormat>(
+ mlirSparseTensorEncodingAttrGetLvlFmt(self, l)));
+ return ret;
+ });
+ }
+};
+
+static void populateDialectSparseTensorSubmodule(nb::module_ &m) {
+ nb::enum_<PySparseTensorLevelFormat>(m, "LevelFormat", nb::is_arithmetic(),
+ nb::is_flag())
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
.value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
.value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
.value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED);
- nb::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty")
+ nb::enum_<PySparseTensorLevelPropertyNondefault>(m, "LevelProperty")
.value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
.value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE)
.value("soa", MLIR_SPARSE_PROPERTY_SOA);
- mlir_attribute_subclass(m, "EncodingAttr",
- mlirAttributeIsASparseTensorEncodingAttr)
- .def_classmethod(
- "get",
- [](const nb::object &cls,
- std::vector<MlirSparseTensorLevelType> lvlTypes,
- std::optional<MlirAffineMap> dimToLvl,
- std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
- std::optional<MlirAttribute> explicitVal,
- std::optional<MlirAttribute> implicitVal, MlirContext context) {
- return cls(mlirSparseTensorEncodingAttrGet(
- context, lvlTypes.size(), lvlTypes.data(),
- dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
- lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
- crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
- implicitVal ? *implicitVal : MlirAttribute{nullptr}));
- },
- nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
- nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
- nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(),
- nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(),
- "Gets a sparse_tensor.encoding from parameters.")
- .def_classmethod(
- "build_level_type",
- [](const nb::object &cls, MlirSparseTensorLevelFormat lvlFmt,
- const std::vector<MlirSparseTensorLevelPropertyNondefault>
- &properties,
- unsigned n, unsigned m) {
- return mlirSparseTensorEncodingAttrBuildLvlType(
- lvlFmt, properties.data(), properties.size(), n, m);
- },
- nb::arg("cls"), nb::arg("lvl_fmt"),
- nb::arg("properties") =
- std::vector<MlirSparseTensorLevelPropertyNondefault>(),
- nb::arg("n") = 0, nb::arg("m") = 0,
- "Builds a sparse_tensor.encoding.level_type from parameters.")
- .def_property_readonly(
- "lvl_types",
- [](MlirAttribute self) {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- std::vector<MlirSparseTensorLevelType> ret;
- ret.reserve(lvlRank);
- for (int l = 0; l < lvlRank; ++l)
- ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
- return ret;
- })
- .def_property_readonly(
- "dim_to_lvl",
- [](MlirAttribute self) -> std::optional<MlirAffineMap> {
- MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
- if (mlirAffineMapIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly(
- "lvl_to_dim",
- [](MlirAttribute self) -> std::optional<MlirAffineMap> {
- MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
- if (mlirAffineMapIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly("pos_width",
- mlirSparseTensorEncodingAttrGetPosWidth)
- .def_property_readonly("crd_width",
- mlirSparseTensorEncodingAttrGetCrdWidth)
- .def_property_readonly(
- "explicit_val",
- [](MlirAttribute self) -> std::optional<MlirAttribute> {
- MlirAttribute ret =
- mlirSparseTensorEncodingAttrGetExplicitVal(self);
- if (mlirAttributeIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly(
- "implicit_val",
- [](MlirAttribute self) -> std::optional<MlirAttribute> {
- MlirAttribute ret =
- mlirSparseTensorEncodingAttrGetImplicitVal(self);
- if (mlirAttributeIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly(
- "structured_n",
- [](MlirAttribute self) -> unsigned {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- return mlirSparseTensorEncodingAttrGetStructuredN(
- mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
- })
- .def_property_readonly(
- "structured_m",
- [](MlirAttribute self) -> unsigned {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- return mlirSparseTensorEncodingAttrGetStructuredM(
- mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
- })
- .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- std::vector<MlirSparseTensorLevelFormat> ret;
- ret.reserve(lvlRank);
- for (int l = 0; l < lvlRank; l++)
- ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
- return ret;
- });
+ EncodingAttr::bind(m);
}
+} // namespace sparse_tensor
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsSparseTensor, m) {
m.doc() = "MLIR SparseTensor dialect.";
- populateDialectSparseTensorSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::sparse_tensor::
+ populateDialectSparseTensorSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 150c69953d960..f42ebd004d09f 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,112 +11,165 @@
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectTransformSubmodule(const nb::module_ &m) {
- //===-------------------------------------------------------------------===//
- // AnyOpType
- //===-------------------------------------------------------------------===//
-
- auto anyOpType =
- mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
- mlirTransformAnyOpTypeGetTypeID);
- anyOpType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirTransformAnyOpTypeGet(ctx));
- },
- "Get an instance of AnyOpType in the given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // AnyParamType
- //===-------------------------------------------------------------------===//
-
- auto anyParamType =
- mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
- mlirTransformAnyParamTypeGetTypeID);
- anyParamType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirTransformAnyParamTypeGet(ctx));
- },
- "Get an instance of AnyParamType in the given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // AnyValueType
- //===-------------------------------------------------------------------===//
-
- auto anyValueType =
- mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
- mlirTransformAnyValueTypeGetTypeID);
- anyValueType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirTransformAnyValueTypeGet(ctx));
- },
- "Get an instance of AnyValueType in the given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // OperationType
- //===-------------------------------------------------------------------===//
-
- auto operationType =
- mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
- mlirTransformOperationTypeGetTypeID);
- operationType.def_classmethod(
- "get",
- [](const nb::object &cls, const std::string &operationName,
- MlirContext ctx) {
- MlirStringRef cOperationName =
- mlirStringRefCreate(operationName.data(), operationName.size());
- return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
- },
- "Get an instance of OperationType for the given kind in the given "
- "context",
- nb::arg("cls"), nb::arg("operation_name"),
- nb::arg("context") = nb::none());
- operationType.def_property_readonly(
- "operation_name",
- [](MlirType type) {
- MlirStringRef operationName =
- mlirTransformOperationTypeGetOperationName(type);
- return nb::str(operationName.data, operationName.length);
- },
- "Get the name of the payload operation accepted by the handle.");
-
- //===-------------------------------------------------------------------===//
- // ParamType
- //===-------------------------------------------------------------------===//
-
- auto paramType =
- mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
- mlirTransformParamTypeGetTypeID);
- paramType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType type, MlirContext ctx) {
- return cls(mlirTransformParamTypeGet(ctx, type));
- },
- "Get an instance of ParamType for the given type in the given context.",
- nb::arg("cls"), nb::arg("type"), nb::arg("context") = nb::none());
- paramType.def_property_readonly(
- "type",
- [](MlirType type) {
- MlirType paramType = mlirTransformParamTypeGetType(type);
- return paramType;
- },
- "Get the type this ParamType is associated with.");
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace transform {
+//===-------------------------------------------------------------------===//
+// AnyOpType
+//===-------------------------------------------------------------------===//
+
+struct AnyOpType : PyConcreteType<AnyOpType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyOpType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformAnyOpTypeGetTypeID;
+ static constexpr const char *pyClassName = "AnyOpType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return AnyOpType(context->getRef(),
+ mlirTransformAnyOpTypeGet(context.get()->get()));
+ },
+ "Get an instance of AnyOpType in the given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// AnyParamType
+//===-------------------------------------------------------------------===//
+
+struct AnyParamType : PyConcreteType<AnyParamType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyParamType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformAnyParamTypeGetTypeID;
+ static constexpr const char *pyClassName = "AnyParamType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return AnyParamType(context->getRef(), mlirTransformAnyParamTypeGet(
+ context.get()->get()));
+ },
+ "Get an instance of AnyParamType in the given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// AnyValueType
+//===-------------------------------------------------------------------===//
+
+struct AnyValueType : PyConcreteType<AnyValueType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyValueType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformAnyValueTypeGetTypeID;
+ static constexpr const char *pyClassName = "AnyValueType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return AnyValueType(context->getRef(), mlirTransformAnyValueTypeGet(
+ context.get()->get()));
+ },
+ "Get an instance of AnyValueType in the given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// OperationType
+//===-------------------------------------------------------------------===//
+
+struct OperationType : PyConcreteType<OperationType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsATransformOperationType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformOperationTypeGetTypeID;
+ static constexpr const char *pyClassName = "OperationType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::string &operationName, DefaultingPyMlirContext context) {
+ MlirStringRef cOperationName =
+ mlirStringRefCreate(operationName.data(), operationName.size());
+ return OperationType(context->getRef(),
+ mlirTransformOperationTypeGet(
+ context.get()->get(), cOperationName));
+ },
+ "Get an instance of OperationType for the given kind in the given "
+ "context",
+ nb::arg("operation_name"), nb::arg("context").none() = nb::none());
+ c.def_prop_ro(
+ "operation_name",
+ [](const PyType &type) {
+ MlirStringRef operationName =
+ mlirTransformOperationTypeGetOperationName(type);
+ return nb::str(operationName.data, operationName.length);
+ },
+ "Get the name of the payload operation accepted by the handle.");
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// ParamType
+//===-------------------------------------------------------------------===//
+
+struct ParamType : PyConcreteType<ParamType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformParamType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformParamTypeGetTypeID;
+ static constexpr const char *pyClassName = "ParamType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &type, DefaultingPyMlirContext context) {
+ return ParamType(context->getRef(), mlirTransformParamTypeGet(
+ context.get()->get(), type));
+ },
+ "Get an instance of ParamType for the given type in the given context.",
+ nb::arg("type"), nb::arg("context").none() = nb::none());
+ c.def_prop_ro(
+ "type",
+ [](PyType type) {
+ return PyType(type.getContext(), mlirTransformParamTypeGetType(type));
+ },
+ "Get the type this ParamType is associated with.");
+ }
+};
+
+static void populateDialectTransformSubmodule(nb::module_ &m) {
+ AnyOpType::bind(m);
+ AnyParamType::bind(m);
+ AnyValueType::bind(m);
+ OperationType::bind(m);
+ ParamType::bind(m);
}
+} // namespace transform
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsTransform, m) {
m.doc() = "MLIR Transform dialect.";
- populateDialectTransformSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::transform::
+ populateDialectTransformSubmodule(m);
}
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a3..b4d19878056db 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -43,8 +43,9 @@ def __init__(
self.parent = parent
self.children = children if children is not None else []
-@ir.register_value_caster(AnyOpType.get_static_typeid())
-@ir.register_value_caster(OperationType.get_static_typeid())
+
+@ir.register_value_caster(AnyOpType.static_typeid)
+@ir.register_value_caster(OperationType.static_typeid)
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
@@ -132,8 +133,8 @@ def print(self, name: Optional[str] = None) -> "OpHandle":
return self
-@ir.register_value_caster(AnyParamType.get_static_typeid())
-@ir.register_value_caster(ParamType.get_static_typeid())
+@ir.register_value_caster(AnyParamType.static_typeid)
+@ir.register_value_caster(ParamType.static_typeid)
class ParamHandle(Handle):
"""Wrapper around a transform param handle."""
@@ -147,7 +148,7 @@ def __init__(
super().__init__(v, parent=parent, children=children)
-@ir.register_value_caster(AnyValueType.get_static_typeid())
+@ir.register_value_caster(AnyValueType.static_typeid)
class ValueHandle(Handle):
"""
Wrapper around a transform value handle with methods to chain further
diff --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py
index dfba2a36b8980..f75428d295c9c 100644
--- a/mlir/test/python/dialects/pdl_types.py
+++ b/mlir/test/python/dialects/pdl_types.py
@@ -5,149 +5,149 @@
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: test_attribute_type
@run
def test_attribute_type():
- with Context():
- parsedType = Type.parse("!pdl.attribute")
- constructedType = pdl.AttributeType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.attribute")
+ constructedType = pdl.AttributeType.get()
- assert pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
+ assert pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
- assert pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
+ assert pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.attribute
- print(parsedType)
- # CHECK: !pdl.attribute
- print(constructedType)
+ # CHECK: !pdl.attribute
+ print(parsedType)
+ # CHECK: !pdl.attribute
+ print(constructedType)
# CHECK-LABEL: TEST: test_operation_type
@run
def test_operation_type():
- with Context():
- parsedType = Type.parse("!pdl.operation")
- constructedType = pdl.OperationType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.operation")
+ constructedType = pdl.OperationType.get()
- assert not pdl.AttributeType.isinstance(parsedType)
- assert pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
- assert not pdl.AttributeType.isinstance(constructedType)
- assert pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.operation
- print(parsedType)
- # CHECK: !pdl.operation
- print(constructedType)
+ # CHECK: !pdl.operation
+ print(parsedType)
+ # CHECK: !pdl.operation
+ print(constructedType)
# CHECK-LABEL: TEST: test_range_type
@run
def test_range_type():
- with Context():
- typeType = Type.parse("!pdl.type")
- parsedType = Type.parse("!pdl.range<type>")
- constructedType = pdl.RangeType.get(typeType)
- elementType = constructedType.element_type
-
- assert not pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
-
- assert not pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
-
- assert parsedType == constructedType
- assert elementType == typeType
-
- # CHECK: !pdl.range<type>
- print(parsedType)
- # CHECK: !pdl.range<type>
- print(constructedType)
- # CHECK: !pdl.type
- print(elementType)
+ with Context():
+ typeType = Type.parse("!pdl.type")
+ parsedType = Type.parse("!pdl.range<type>")
+ constructedType = pdl.RangeType.get(typeType)
+ elementType = constructedType.element_type
+
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
+
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
+
+ assert parsedType == constructedType
+ assert elementType == typeType
+
+ # CHECK: !pdl.range<type>
+ print(parsedType)
+ # CHECK: !pdl.range<type>
+ print(constructedType)
+ # CHECK: !pdl.type
+ print(elementType)
# CHECK-LABEL: TEST: test_type_type
@run
def test_type_type():
- with Context():
- parsedType = Type.parse("!pdl.type")
- constructedType = pdl.TypeType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.type")
+ constructedType = pdl.TypeType.get()
- assert not pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
- assert not pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.type
- print(parsedType)
- # CHECK: !pdl.type
- print(constructedType)
+ # CHECK: !pdl.type
+ print(parsedType)
+ # CHECK: !pdl.type
+ print(constructedType)
# CHECK-LABEL: TEST: test_value_type
@run
def test_value_type():
- with Context():
- parsedType = Type.parse("!pdl.value")
- constructedType = pdl.ValueType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.value")
+ constructedType = pdl.ValueType.get()
- assert not pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert pdl.ValueType.isinstance(parsedType)
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert pdl.ValueType.isinstance(parsedType)
- assert not pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert pdl.ValueType.isinstance(constructedType)
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.value
- print(parsedType)
- # CHECK: !pdl.value
- print(constructedType)
+ # CHECK: !pdl.value
+ print(parsedType)
+ # CHECK: !pdl.value
+ print(constructedType)
# CHECK-LABEL: TEST: test_type_without_context
@@ -157,7 +157,10 @@ def test_type_without_context():
# should raise an exception but not crash.
try:
constructedType = pdl.ValueType.get()
- except TypeError:
- pass
+ except RuntimeError as e:
+ assert (
+ "An MLIR function requires a Context but none was provided in the call or from the surrounding environment"
+ in e.args[0]
+ )
else:
assert False, "Expected TypeError to be raised."
>From dd70894dc8bd302c45bcae9c23e4f033f4a68d1e Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 3 Jan 2026 18:55:39 -0800
Subject: [PATCH 3/3] update signatures
---
mlir/lib/Bindings/Python/DialectGPU.cpp | 40 +++---
mlir/lib/Bindings/Python/DialectLLVM.cpp | 62 +++++----
mlir/lib/Bindings/Python/DialectPDL.cpp | 6 +-
mlir/lib/Bindings/Python/DialectQuant.cpp | 68 ++++++----
mlir/lib/Bindings/Python/DialectSMT.cpp | 6 +-
.../Bindings/Python/DialectSparseTensor.cpp | 32 ++---
mlir/lib/Bindings/Python/DialectTransform.cpp | 4 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 124 ++++++------------
8 files changed, 163 insertions(+), 179 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 3ea8edec7b136..469fd524e8942 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -13,6 +13,8 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include <mlir/Bindings/Python/IRAttributes.h>
+
namespace nb = nanobind;
using namespace nanobind::literals;
using namespace mlir::python::nanobind_adaptors;
@@ -54,9 +56,9 @@ struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](MlirAttribute target, uint32_t format, const nb::bytes &object,
- std::optional<MlirAttribute> mlirObjectProps,
- std::optional<MlirAttribute> mlirKernelsAttr,
+ [](const PyAttribute &target, uint32_t format, const nb::bytes &object,
+ std::optional<PyDictAttribute> mlirObjectProps,
+ std::optional<PyAttribute> mlirKernelsAttr,
DefaultingPyMlirContext context) {
MlirStringRef objectStrRef = mlirStringRefCreate(
static_cast<char *>(const_cast<void *>(object.data())),
@@ -74,26 +76,30 @@ struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
"kernels"_a = nb::none(), "context"_a = nb::none(),
"Gets a gpu.object from parameters.");
- c.def_prop_ro("target", [](MlirAttribute self) {
- return mlirGPUObjectAttrGetTarget(self);
+ c.def_prop_ro("target", [](ObjectAttr &self) {
+ return PyAttribute(self.getContext(), mlirGPUObjectAttrGetTarget(self));
});
- c.def_prop_ro("format", [](MlirAttribute self) {
+ c.def_prop_ro("format", [](const ObjectAttr &self) {
return mlirGPUObjectAttrGetFormat(self);
});
- c.def_prop_ro("object", [](MlirAttribute self) {
+ c.def_prop_ro("object", [](const ObjectAttr &self) {
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
return nb::bytes(stringRef.data, stringRef.length);
});
- c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasProperties(self))
- return nb::cast(mlirGPUObjectAttrGetProperties(self));
- return nb::none();
- });
- c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasKernels(self))
- return nb::cast(mlirGPUObjectAttrGetKernels(self));
- return nb::none();
- });
+ c.def_prop_ro(
+ "properties", [](ObjectAttr &self) -> std::optional<PyDictAttribute> {
+ if (mlirGPUObjectAttrHasProperties(self))
+ return PyDictAttribute(self.getContext(),
+ mlirGPUObjectAttrGetProperties(self));
+ return std::nullopt;
+ });
+ c.def_prop_ro("kernels",
+ [](ObjectAttr &self) -> std::optional<PyAttribute> {
+ if (mlirGPUObjectAttrHasKernels(self))
+ return PyAttribute(self.getContext(),
+ mlirGPUObjectAttrGetKernels(self));
+ return std::nullopt;
+ });
}
};
} // namespace gpu
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index d4eb078c0f55c..ff31398225a9c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -42,13 +42,16 @@ struct StructType : PyConcreteType<StructType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get_literal",
- [](const std::vector<MlirType> &elements, bool packed, MlirLocation loc,
+ [](const std::vector<PyType> &elements, bool packed, MlirLocation loc,
DefaultingPyMlirContext context) {
python::CollectDiagnosticsToStringScope scope(
mlirLocationGetContext(loc));
+ std::vector<MlirType> elements_(elements.size());
+ std::transform(elements.begin(), elements.end(), elements_.begin(),
+ [](const PyType &elem) { return elem; });
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
+ loc, elements.size(), elements_.data(), packed);
if (mlirTypeIsNull(type)) {
throw nb::value_error(scope.takeMessage().c_str());
}
@@ -59,12 +62,16 @@ struct StructType : PyConcreteType<StructType> {
c.def_static(
"get_literal_unchecked",
- [](const std::vector<MlirType> &elements, bool packed,
+ [](const std::vector<PyType> &elements, bool packed,
DefaultingPyMlirContext context) {
python::CollectDiagnosticsToStringScope scope(context.get()->get());
+ std::vector<MlirType> elements_(elements.size());
+ std::transform(elements.begin(), elements.end(), elements_.begin(),
+ [](const PyType &elem) { return elem; });
+
MlirType type = mlirLLVMStructTypeLiteralGet(
- context.get()->get(), elements.size(), elements.data(), packed);
+ context.get()->get(), elements.size(), elements_.data(), packed);
if (mlirTypeIsNull(type)) {
throw nb::value_error(scope.takeMessage().c_str());
}
@@ -95,9 +102,13 @@ struct StructType : PyConcreteType<StructType> {
c.def(
"set_body",
- [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+ [](const StructType &self, const std::vector<PyType> &elements,
+ bool packed) {
+ std::vector<MlirType> elements_(elements.size());
+ std::transform(elements.begin(), elements.end(), elements_.begin(),
+ [](const PyType &elem) { return elem; });
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
- self, elements.size(), elements.data(), packed);
+ self, elements.size(), elements_.data(), packed);
if (!mlirLogicalResultIsSuccess(result)) {
throw nb::value_error(
"Struct body already set to different content.");
@@ -107,26 +118,30 @@ struct StructType : PyConcreteType<StructType> {
c.def_static(
"new_identified",
- [](const std::string &name, const std::vector<MlirType> &elements,
+ [](const std::string &name, const std::vector<PyType> &elements,
bool packed, DefaultingPyMlirContext context) {
+ std::vector<MlirType> elements_(elements.size());
+ std::transform(elements.begin(), elements.end(), elements_.begin(),
+ [](const PyType &elem) { return elem; });
return StructType(context->getRef(),
mlirLLVMStructTypeIdentifiedNewGet(
context.get()->get(),
mlirStringRefCreate(name.data(), name.length()),
- elements.size(), elements.data(), packed));
+ elements.size(), elements_.data(), packed));
},
"name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"context"_a = nb::none());
- c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
- if (mlirLLVMStructTypeIsLiteral(type))
- return std::nullopt;
+ c.def_prop_ro(
+ "name", [](const StructType &type) -> std::optional<std::string> {
+ if (mlirLLVMStructTypeIsLiteral(type))
+ return std::nullopt;
- MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
- return StringRef(stringRef.data, stringRef.length).str();
- });
+ MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+ return StringRef(stringRef.data, stringRef.length).str();
+ });
- c.def_prop_ro("body", [](PyType type) -> nb::object {
+ c.def_prop_ro("body", [](const StructType &type) -> nb::object {
// Don't crash in absence of a body.
if (mlirLLVMStructTypeIsOpaque(type))
return nb::none();
@@ -139,11 +154,13 @@ struct StructType : PyConcreteType<StructType> {
return body;
});
- c.def_prop_ro("packed",
- [](PyType type) { return mlirLLVMStructTypeIsPacked(type); });
+ c.def_prop_ro("packed", [](const StructType &type) {
+ return mlirLLVMStructTypeIsPacked(type);
+ });
- c.def_prop_ro("opaque",
- [](PyType type) { return mlirLLVMStructTypeIsOpaque(type); });
+ c.def_prop_ro("opaque", [](const StructType &type) {
+ return mlirLLVMStructTypeIsOpaque(type);
+ });
}
};
@@ -174,7 +191,7 @@ struct PointerType : PyConcreteType<PointerType> {
},
"address_space"_a = nb::none(), nb::kw_only(),
"context"_a = nb::none());
- c.def_prop_ro("address_space", [](PyType type) {
+ c.def_prop_ro("address_space", [](const PointerType &type) {
return mlirLLVMPointerTypeGetAddressSpace(type);
});
}
@@ -186,12 +203,9 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
m.def(
"translate_module_to_llvmir",
- [](MlirOperation module) {
+ [](const PyOperation &module) {
return mlirTranslateModuleToLLVMIRToString(module);
},
- // clang-format off
- nb::sig("def translate_module_to_llvmir(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> str"),
- // clang-format on
"module"_a, nb::rv_policy::take_ownership);
}
} // namespace llvm
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index d2ed3b141d724..5bb51eb63ce56 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -87,7 +87,7 @@ struct RangeType : PyConcreteType<RangeType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](PyType &elementType, DefaultingPyMlirContext context) {
+ [](const PyType &elementType, DefaultingPyMlirContext context) {
return RangeType(context->getRef(), mlirPDLRangeTypeGet(elementType));
},
"Gets an instance of RangeType in the same context as the provided "
@@ -95,12 +95,10 @@ struct RangeType : PyConcreteType<RangeType> {
nb::arg("element_type"), nb::arg("context").none() = nb::none());
c.def_prop_ro(
"element_type",
- [](PyType &type) {
+ [](RangeType &type) {
return PyType(type.getContext(),
mlirPDLRangeTypeGetElementType(type));
},
- nb::sig(
- "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
"Get the element type.");
}
};
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index a1e0a281a708d..3a9b8ffdf8971 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -14,6 +14,8 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include <mlir/Bindings/Python/IRAttributes.h>
+
namespace nb = nanobind;
using namespace llvm;
using namespace mlir::python::nanobind_adaptors;
@@ -54,48 +56,52 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
nb::arg("is_signed"), nb::arg("integral_width"));
c.def_prop_ro(
"expressed_type",
- [](PyType type) {
+ [](QuantizedType &type) {
return PyType(type.getContext(),
mlirQuantizedTypeGetExpressedType(type));
},
"Type expressed by this quantized type.");
c.def_prop_ro(
"flags",
- [](const PyType &type) { return mlirQuantizedTypeGetFlags(type); },
+ [](const QuantizedType &type) {
+ return mlirQuantizedTypeGetFlags(type);
+ },
"Flags of this quantized type (named accessors should be preferred to "
"this)");
c.def_prop_ro(
"is_signed",
- [](const PyType &type) { return mlirQuantizedTypeIsSigned(type); },
+ [](const QuantizedType &type) {
+ return mlirQuantizedTypeIsSigned(type);
+ },
"Signedness of this quantized type.");
c.def_prop_ro(
"storage_type",
- [](PyType type) {
+ [](QuantizedType &type) {
return PyType(type.getContext(),
mlirQuantizedTypeGetStorageType(type));
},
"Storage type backing this quantized type.");
c.def_prop_ro(
"storage_type_min",
- [](const PyType &type) {
+ [](const QuantizedType &type) {
return mlirQuantizedTypeGetStorageTypeMin(type);
},
"The minimum value held by the storage type of this quantized type.");
c.def_prop_ro(
"storage_type_max",
- [](const PyType &type) {
+ [](const QuantizedType &type) {
return mlirQuantizedTypeGetStorageTypeMax(type);
},
"The maximum value held by the storage type of this quantized type.");
c.def_prop_ro(
"storage_type_integral_width",
- [](const PyType &type) {
+ [](const QuantizedType &type) {
return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
},
"The bitwidth of the storage type of this quantized type.");
c.def(
"is_compatible_expressed_type",
- [](const PyType &type, const PyType &candidate) {
+ [](const QuantizedType &type, const PyType &candidate) {
return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
},
"Checks whether the candidate type can be expressed by this quantized "
@@ -103,14 +109,14 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
nb::arg("candidate"));
c.def_prop_ro(
"quantized_element_type",
- [](PyType type) {
+ [](QuantizedType &type) {
return PyType(type.getContext(),
mlirQuantizedTypeGetQuantizedElementType(type));
},
"Element type of this quantized type expressed as quantized type.");
c.def(
"cast_from_storage_type",
- [](PyType type, const PyType &candidate) {
+ [](QuantizedType &type, const PyType &candidate) {
MlirType castResult =
mlirQuantizedTypeCastFromStorageType(type, candidate);
if (!mlirTypeIsNull(castResult))
@@ -125,10 +131,10 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
nb::arg("candidate"));
c.def_static(
"cast_to_storage_type",
- [](const PyType &type) {
+ [](QuantizedType &type) {
MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
if (!mlirTypeIsNull(castResult))
- return castResult;
+ return PyType(type.getContext(), castResult);
throw nb::type_error("Invalid cast.");
},
"Casts from a type based on a quantized type to a corresponding type "
@@ -137,7 +143,7 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
nb::arg("type"));
c.def(
"cast_from_expressed_type",
- [](PyType type, const PyType &candidate) {
+ [](QuantizedType &type, const PyType &candidate) {
MlirType castResult =
mlirQuantizedTypeCastFromExpressedType(type, candidate);
if (!mlirTypeIsNull(castResult))
@@ -151,10 +157,10 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
nb::arg("candidate"));
c.def_static(
"cast_to_expressed_type",
- [](const PyType &type) {
+ [](QuantizedType &type) {
MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
if (!mlirTypeIsNull(castResult))
- return castResult;
+ return PyType(type.getContext(), castResult);
throw nb::type_error("Invalid cast.");
},
"Casts from a type based on a quantized type to a corresponding type "
@@ -164,7 +170,7 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
nb::arg("type"));
c.def(
"cast_expressed_to_storage_type",
- [](PyType type, const PyType &candidate) {
+ [](QuantizedType &type, const PyType &candidate) {
MlirType castResult =
mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
if (!mlirTypeIsNull(castResult))
@@ -238,21 +244,21 @@ struct UniformQuantizedType
nb::arg("storage_type_max"), nb::arg("context") = nb::none());
c.def_prop_ro(
"scale",
- [](const PyType &type) {
+ [](const UniformQuantizedType &type) {
return mlirUniformQuantizedTypeGetScale(type);
},
"The scale designates the difference between the real values "
"corresponding to consecutive quantized values differing by 1.");
c.def_prop_ro(
"zero_point",
- [](const PyType &type) {
+ [](const UniformQuantizedType &type) {
return mlirUniformQuantizedTypeGetZeroPoint(type);
},
"The storage value corresponding to the real value 0 in the affine "
"equation.");
c.def_prop_ro(
"is_fixed_point",
- [](const PyType &type) {
+ [](const UniformQuantizedType &type) {
return mlirUniformQuantizedTypeIsFixedPoint(type);
},
"Fixed point values are real numbers divided by a scale.");
@@ -298,7 +304,7 @@ struct UniformQuantizedPerAxisType
nb::arg("storage_type_max"), nb::arg("context") = nb::none());
c.def_prop_ro(
"scales",
- [](const PyType &type) {
+ [](const UniformQuantizedPerAxisType &type) {
intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
std::vector<double> scales;
scales.reserve(nDim);
@@ -313,7 +319,7 @@ struct UniformQuantizedPerAxisType
"scale corresponds to the ith slice in the quantized_dimension.");
c.def_prop_ro(
"zero_points",
- [](const PyType &type) {
+ [](const UniformQuantizedPerAxisType &type) {
intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
std::vector<int64_t> zeroPoints;
zeroPoints.reserve(nDim);
@@ -329,14 +335,14 @@ struct UniformQuantizedPerAxisType
"quantized_dimension.");
c.def_prop_ro(
"quantized_dimension",
- [](const PyType &type) {
+ [](const UniformQuantizedPerAxisType &type) {
return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
},
"Specifies the dimension of the shape that the scales and zero points "
"correspond to.");
c.def_prop_ro(
"is_fixed_point",
- [](const PyType &type) {
+ [](const UniformQuantizedPerAxisType &type) {
return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
},
"Fixed point values are real numbers divided by a scale.");
@@ -379,7 +385,7 @@ struct UniformQuantizedSubChannelType
nb::arg("context") = nb::none());
c.def_prop_ro(
"quantized_dimensions",
- [](const PyType &type) {
+ [](const UniformQuantizedSubChannelType &type) {
intptr_t nDim =
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
std::vector<int32_t> quantizedDimensions;
@@ -399,7 +405,7 @@ struct UniformQuantizedSubChannelType
"i-th block size from block_sizes method.");
c.def_prop_ro(
"block_sizes",
- [](const PyType &type) {
+ [](const UniformQuantizedSubChannelType &type) {
intptr_t nDim =
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
std::vector<int64_t> blockSizes;
@@ -417,14 +423,18 @@ struct UniformQuantizedSubChannelType
"in the list returned by quantized_dimensions method.");
c.def_prop_ro(
"scales",
- [](const PyType &type) -> MlirAttribute {
- return mlirUniformQuantizedSubChannelTypeGetScales(type);
+ [](UniformQuantizedSubChannelType &type) {
+ return PyDenseElementsAttribute(
+ type.getContext(),
+ mlirUniformQuantizedSubChannelTypeGetScales(type));
},
"The scales of the quantized type.");
c.def_prop_ro(
"zero_points",
- [](const PyType &type) -> MlirAttribute {
- return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+ [](UniformQuantizedSubChannelType &type) {
+ return PyDenseElementsAttribute(
+ type.getContext(),
+ mlirUniformQuantizedSubChannelTypeGetZeroPoints(type));
},
"The zero points of the quantized type.");
}
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 39490155d5216..2c12341b81439 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -97,7 +97,7 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
m.def(
"export_smtlib",
- [&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
+ [&exportSMTLIB](const PyOperation &module, bool inlineSingleUseValues,
bool indentLetBody) {
return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
},
@@ -105,9 +105,9 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
"indent_let_body"_a = false);
m.def(
"export_smtlib",
- [&exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
+ [&exportSMTLIB](PyModule &module, bool inlineSingleUseValues,
bool indentLetBody) {
- return exportSMTLIB(mlirModuleGetOperation(module),
+ return exportSMTLIB(mlirModuleGetOperation(module.get()),
inlineSingleUseValues, indentLetBody);
},
"module"_a, "inline_single_use_values"_a = false,
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 6ec58dd88d24f..ca197ba32e074 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -52,10 +52,10 @@ struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
c.def_static(
"get",
[](std::vector<MlirSparseTensorLevelType> lvlTypes,
- std::optional<MlirAffineMap> dimToLvl,
- std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
- std::optional<MlirAttribute> explicitVal,
- std::optional<MlirAttribute> implicitVal,
+ std::optional<PyAffineMap> dimToLvl,
+ std::optional<PyAffineMap> lvlToDim, int posWidth, int crdWidth,
+ std::optional<PyAttribute> explicitVal,
+ std::optional<PyAttribute> implicitVal,
DefaultingPyMlirContext context) {
return EncodingAttr(
context->getRef(),
@@ -93,7 +93,7 @@ struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
nb::arg("n") = 0, nb::arg("m") = 0,
"Builds a sparse_tensor.encoding.level_type from parameters.");
- c.def_prop_ro("lvl_types", [](MlirAttribute self) {
+ c.def_prop_ro("lvl_types", [](const EncodingAttr &self) {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
std::vector<MlirSparseTensorLevelType> ret;
ret.reserve(lvlRank);
@@ -103,53 +103,53 @@ struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
});
c.def_prop_ro(
- "dim_to_lvl", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+ "dim_to_lvl", [](EncodingAttr &self) -> std::optional<PyAffineMap> {
MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
if (mlirAffineMapIsNull(ret))
return {};
- return ret;
+ return PyAffineMap(self.getContext(), ret);
});
c.def_prop_ro(
- "lvl_to_dim", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+ "lvl_to_dim", [](EncodingAttr &self) -> std::optional<PyAffineMap> {
MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
if (mlirAffineMapIsNull(ret))
return {};
- return ret;
+ return PyAffineMap(self.getContext(), ret);
});
c.def_prop_ro("pos_width", mlirSparseTensorEncodingAttrGetPosWidth);
c.def_prop_ro("crd_width", mlirSparseTensorEncodingAttrGetCrdWidth);
c.def_prop_ro(
- "explicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+ "explicit_val", [](EncodingAttr &self) -> std::optional<PyAttribute> {
MlirAttribute ret = mlirSparseTensorEncodingAttrGetExplicitVal(self);
if (mlirAttributeIsNull(ret))
return {};
- return ret;
+ return PyAttribute(self.getContext(), ret);
});
c.def_prop_ro(
- "implicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+ "implicit_val", [](EncodingAttr &self) -> std::optional<PyAttribute> {
MlirAttribute ret = mlirSparseTensorEncodingAttrGetImplicitVal(self);
if (mlirAttributeIsNull(ret))
return {};
- return ret;
+ return PyAttribute(self.getContext(), ret);
});
- c.def_prop_ro("structured_n", [](MlirAttribute self) -> unsigned {
+ c.def_prop_ro("structured_n", [](const EncodingAttr &self) -> unsigned {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
return mlirSparseTensorEncodingAttrGetStructuredN(
mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
});
- c.def_prop_ro("structured_m", [](MlirAttribute self) -> unsigned {
+ c.def_prop_ro("structured_m", [](const EncodingAttr &self) -> unsigned {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
return mlirSparseTensorEncodingAttrGetStructuredM(
mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
});
- c.def_prop_ro("lvl_formats_enum", [](MlirAttribute self) {
+ c.def_prop_ro("lvl_formats_enum", [](const EncodingAttr &self) {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
std::vector<PySparseTensorLevelFormat> ret;
ret.reserve(lvlRank);
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index f42ebd004d09f..19e6418f067bb 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -118,7 +118,7 @@ struct OperationType : PyConcreteType<OperationType> {
nb::arg("operation_name"), nb::arg("context").none() = nb::none());
c.def_prop_ro(
"operation_name",
- [](const PyType &type) {
+ [](const OperationType &type) {
MlirStringRef operationName =
mlirTransformOperationTypeGetOperationName(type);
return nb::str(operationName.data, operationName.length);
@@ -149,7 +149,7 @@ struct ParamType : PyConcreteType<ParamType> {
nb::arg("type"), nb::arg("context").none() = nb::none());
c.def_prop_ro(
"type",
- [](PyType type) {
+ [](ParamType type) {
return PyType(type.getContext(), mlirTransformParamTypeGetType(type));
},
"Get the type this ParamType is associated with.");
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index c282f4b6996e5..f04b9b7788dd0 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -55,7 +55,7 @@ class PyPatternRewriter {
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
}
- void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+ void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
private:
MlirRewriterBase base;
@@ -342,38 +342,30 @@ void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
- nb::
- class_<PyPatternRewriter>(m, "PatternRewriter")
- .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.")
- .def(
- "replace_op",
- [](PyPatternRewriter &self, MlirOperation op,
- MlirOperation newOp) { self.replaceOp(op, newOp); },
- "Replace an operation with a new operation.", nb::arg("op"),
- nb::arg("new_op"),
- // clang-format off
- nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
- // clang-format on
- )
- .def(
- "replace_op",
- [](PyPatternRewriter &self, MlirOperation op,
- const std::vector<MlirValue> &values) {
- self.replaceOp(op, values);
- },
- "Replace an operation with a list of values.", nb::arg("op"),
- nb::arg("values"),
- // clang-format off
- nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
- // clang-format on
- )
- .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
- nb::arg("op"),
- // clang-format off
- nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
- // clang-format on
- );
+ nb::class_<PyPatternRewriter>(m, "PatternRewriter")
+ .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+ "The current insertion point of the PatternRewriter.")
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, PyOperationBase &op,
+ PyOperationBase &newOp) {
+ self.replaceOp(op.getOperation(), newOp.getOperation());
+ },
+ "Replace an operation with a new operation.", nb::arg("op"),
+ nb::arg("new_op"))
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, PyOperationBase &op,
+ const std::vector<PyValue> &values) {
+ std::vector<MlirValue> values_(values.size());
+ std::transform(values.begin(), values.end(), values_.begin(),
+ [](const PyValue &val) { return val; });
+ self.replaceOp(op.getOperation(), values_);
+ },
+ "Replace an operation with a list of values.", nb::arg("op"),
+ nb::arg("values"))
+ .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+ nb::arg("op"));
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
@@ -428,42 +420,21 @@ void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
- .def(
- "append",
- [](PyMlirPDLResultList results, const PyValue &value) {
- mlirPDLResultListPushBackValue(results, value);
- },
- // clang-format off
- nb::sig("def append(self, value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
- // clang-format on
- )
- .def(
- "append",
- [](PyMlirPDLResultList results, const PyOperation &op) {
- mlirPDLResultListPushBackOperation(results, op);
- },
- // clang-format off
- nb::sig("def append(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
- // clang-format on
- )
- .def(
- "append",
- [](PyMlirPDLResultList results, const PyType &type) {
- mlirPDLResultListPushBackType(results, type);
- },
- // clang-format off
- nb::sig("def append(self, type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
- // clang-format on
- )
- .def(
- "append",
- [](PyMlirPDLResultList results, const PyAttribute &attr) {
- mlirPDLResultListPushBackAttribute(results, attr);
- },
- // clang-format off
- nb::sig("def append(self, attr: " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
- // clang-format on
- );
+ .def("append",
+ [](PyMlirPDLResultList results, const PyValue &value) {
+ mlirPDLResultListPushBackValue(results, value);
+ })
+ .def("append",
+ [](PyMlirPDLResultList results, const PyOperation &op) {
+ mlirPDLResultListPushBackOperation(results, op);
+ })
+ .def("append",
+ [](PyMlirPDLResultList results, const PyType &type) {
+ mlirPDLResultListPushBackType(results, type);
+ })
+ .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) {
+ mlirPDLResultListPushBackAttribute(results, attr);
+ });
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
@@ -471,9 +442,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
new (&self) PyPDLPatternModule(
mlirPDLPatternModuleFromModule(module.get()));
},
- // clang-format off
- nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
- // clang-format on
"module"_a, "Create a PDL module from the given module.")
.def(
"__init__",
@@ -481,9 +449,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
new (&self) PyPDLPatternModule(
mlirPDLPatternModuleFromModule(module.get()));
},
- // clang-format off
- nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
- // clang-format on
"module"_a, "Create a PDL module from the given module.")
.def(
"freeze",
@@ -552,9 +517,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
throw std::runtime_error("pattern application failed to converge");
},
"module"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
@@ -568,9 +530,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
"pattern application failed to converge");
},
"op"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
@@ -579,9 +538,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
mlirWalkAndApplyPatterns(op.getOperation(), set.get());
},
"op"_a, "set"_a,
- // clang-format off
- nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
"Applies the given patterns to the given op by a fast walk-based "
"driver.");
}
More information about the llvm-branch-commits
mailing list