[llvm-branch-commits] [mlir] [mlir][Python] port dialect extensions to use core PyConcreteType, PyConcreteAttribute (PR #174156)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 2 12:24:33 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
This PR ports all in-tree dialect extensions to use the `PyConcreteType` and `PyConcreteAttribute` CRTP base classes instead of the `mlir_pure_subclass`-based helpers (`mlir_type_subclass`, `mlir_attribute_subclass`). After this PR, `mlir_pure_subclass` can be soft-deprecated.
This PR depends on https://github.com/llvm/llvm-project/pull/174118
---
The patch is 111.46 KiB and is truncated to 20.00 KiB below; the full version is available at https://github.com/llvm/llvm-project/pull/174156.diff
11 Files Affected:
- (modified) mlir/lib/Bindings/Python/DialectAMDGPU.cpp (+74-36)
- (modified) mlir/lib/Bindings/Python/DialectGPU.cpp (+87-65)
- (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+164-133)
- (modified) mlir/lib/Bindings/Python/DialectNVGPU.cpp (+31-18)
- (modified) mlir/lib/Bindings/Python/DialectPDL.cpp (+145-83)
- (modified) mlir/lib/Bindings/Python/DialectQuant.cpp (+454-355)
- (modified) mlir/lib/Bindings/Python/DialectSMT.cpp (+63-26)
- (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+125-109)
- (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+150-98)
- (modified) mlir/python/mlir/dialects/transform/extras/__init__.py (+6-5)
- (modified) mlir/test/python/dialects/pdl_types.py (+107-104)
``````````diff
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
index 26ffc0e427e41..26115c3635b7b 100644
--- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -8,58 +8,96 @@
#include "mlir-c/Dialect/AMDGPU.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
- auto amdgpuTDMBaseType =
- mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
- mlirAMDGPUTDMBaseTypeGetTypeID);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace amdgpu {
+struct TDMBaseType : PyConcreteType<TDMBaseType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMBaseTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMBaseType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
- return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
- },
- "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
- nb::arg("element_type"), nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &elementType, DefaultingPyMlirContext context) {
+ return TDMBaseType(
+ context->getRef(),
+ mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType));
+ },
+ "Gets an instance of TDMBaseType in the same context",
+ nb::arg("element_type"), nb::arg("context").none() = nb::none());
+ }
+};
- auto amdgpuTDMDescriptorType = mlir_type_subclass(
- m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
- mlirAMDGPUTDMDescriptorTypeGetTypeID);
+struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAAMDGPUTDMDescriptorType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMDescriptorTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMDescriptorType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMDescriptorType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
- },
- "Gets an instance of TDMDescriptorType in the same context",
- nb::arg("cls"), nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return TDMDescriptorType(
+ context->getRef(),
+ mlirAMDGPUTDMDescriptorTypeGet(context.get()->get()));
+ },
+ "Gets an instance of TDMDescriptorType in the same context",
+ nb::arg("context").none() = nb::none());
+ }
+};
- auto amdgpuTDMGatherBaseType = mlir_type_subclass(
- m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
- mlirAMDGPUTDMGatherBaseTypeGetTypeID);
+struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAAMDGPUTDMGatherBaseType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMGatherBaseTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMGatherBaseType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMGatherBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirType indexType,
- MlirContext ctx) {
- return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
- },
- "Gets an instance of TDMGatherBaseType in the same context",
- nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
- nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &elementType, const PyType &indexType,
+ DefaultingPyMlirContext context) {
+ return TDMGatherBaseType(
+ context->getRef(),
+ mlirAMDGPUTDMGatherBaseTypeGet(context.get()->get(), elementType,
+ indexType));
+ },
+ "Gets an instance of TDMGatherBaseType in the same context",
+ nb::arg("element_type"), nb::arg("index_type"),
+ nb::arg("context").none() = nb::none());
+ }
};
+static void populateDialectAMDGPUSubmodule(nb::module_ &m) {
+ TDMBaseType::bind(m);
+ TDMDescriptorType::bind(m);
+ TDMGatherBaseType::bind(m);
+}
+} // namespace amdgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
NB_MODULE(_mlirDialectsAMDGPU, m) {
m.doc() = "MLIR AMDGPU dialect.";
- populateDialectAMDGPUSubmodule(m);
+ mlir::python::mlir::amdgpu::populateDialectAMDGPUSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 2568d535edb5a..ea3748cc88b85 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,83 +9,105 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace gpu {
// -----------------------------------------------------------------------------
-// Module initialization.
+// AsyncTokenType
// -----------------------------------------------------------------------------
-NB_MODULE(_mlirDialectsGPU, m) {
- m.doc() = "MLIR GPU Dialect";
- //===-------------------------------------------------------------------===//
- // AsyncTokenType
- //===-------------------------------------------------------------------===//
+struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
+ static constexpr const char *pyClassName = "AsyncTokenType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return AsyncTokenType(context->getRef(),
+ mlirGPUAsyncTokenTypeGet(context.get()->get()));
+ },
+ "Gets an instance of AsyncTokenType in the same context",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// ObjectAttr
+//===-------------------------------------------------------------------===//
+
+struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr;
+ static constexpr const char *pyClassName = "ObjectAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
- auto mlirGPUAsyncTokenType =
- mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](MlirAttribute target, uint32_t format, const nb::bytes &object,
+ std::optional<MlirAttribute> mlirObjectProps,
+ std::optional<MlirAttribute> mlirKernelsAttr,
+ DefaultingPyMlirContext context) {
+ MlirStringRef objectStrRef = mlirStringRefCreate(
+ static_cast<char *>(const_cast<void *>(object.data())),
+ object.size());
+ return ObjectAttr(
+ context->getRef(),
+ mlirGPUObjectAttrGetWithKernels(
+ mlirAttributeGetContext(target), target, format, objectStrRef,
+ mlirObjectProps.has_value() ? *mlirObjectProps
+ : MlirAttribute{nullptr},
+ mlirKernelsAttr.has_value() ? *mlirKernelsAttr
+ : MlirAttribute{nullptr}));
+ },
+ "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(),
+ "kernels"_a = nb::none(), "context"_a = nb::none(),
+ "Gets a gpu.object from parameters.");
- mlirGPUAsyncTokenType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirGPUAsyncTokenTypeGet(ctx));
- },
- "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
- nb::arg("ctx") = nb::none());
+ c.def_prop_ro("target", [](MlirAttribute self) {
+ return mlirGPUObjectAttrGetTarget(self);
+ });
+ c.def_prop_ro("format", [](MlirAttribute self) {
+ return mlirGPUObjectAttrGetFormat(self);
+ });
+ c.def_prop_ro("object", [](MlirAttribute self) {
+ MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+ return nb::bytes(stringRef.data, stringRef.length);
+ });
+ c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
+ if (mlirGPUObjectAttrHasProperties(self))
+ return nb::cast(mlirGPUObjectAttrGetProperties(self));
+ return nb::none();
+ });
+ c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
+ if (mlirGPUObjectAttrHasKernels(self))
+ return nb::cast(mlirGPUObjectAttrGetKernels(self));
+ return nb::none();
+ });
+ }
+};
+} // namespace gpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
- //===-------------------------------------------------------------------===//
- // ObjectAttr
- //===-------------------------------------------------------------------===//
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+NB_MODULE(_mlirDialectsGPU, m) {
+ m.doc() = "MLIR GPU Dialect";
- mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirAttribute target, uint32_t format,
- const nb::bytes &object,
- std::optional<MlirAttribute> mlirObjectProps,
- std::optional<MlirAttribute> mlirKernelsAttr) {
- MlirStringRef objectStrRef = mlirStringRefCreate(
- static_cast<char *>(const_cast<void *>(object.data())),
- object.size());
- return cls(mlirGPUObjectAttrGetWithKernels(
- mlirAttributeGetContext(target), target, format, objectStrRef,
- mlirObjectProps.has_value() ? *mlirObjectProps
- : MlirAttribute{nullptr},
- mlirKernelsAttr.has_value() ? *mlirKernelsAttr
- : MlirAttribute{nullptr}));
- },
- "cls"_a, "target"_a, "format"_a, "object"_a,
- "properties"_a = nb::none(), "kernels"_a = nb::none(),
- "Gets a gpu.object from parameters.")
- .def_property_readonly(
- "target",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
- .def_property_readonly(
- "format",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
- .def_property_readonly(
- "object",
- [](MlirAttribute self) {
- MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
- return nb::bytes(stringRef.data, stringRef.length);
- })
- .def_property_readonly("properties",
- [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasProperties(self))
- return nb::cast(
- mlirGPUObjectAttrGetProperties(self));
- return nb::none();
- })
- .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasKernels(self))
- return nb::cast(mlirGPUObjectAttrGetKernels(self));
- return nb::none();
- });
+ mlir::python::mlir::gpu::AsyncTokenType::bind(m);
+ mlir::python::mlir::gpu::ObjectAttr::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 05681cecf82b3..d4eb078c0f55c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -13,149 +13,176 @@
#include "mlir-c/Support.h"
#include "mlir-c/Target/LLVMIR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
using namespace llvm;
using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
- //===--------------------------------------------------------------------===//
- // StructType
- //===--------------------------------------------------------------------===//
-
- auto llvmStructType = mlir_type_subclass(
- m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
- llvmStructType
- .def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none())
- .def_classmethod(
- "get_literal_unchecked",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirContext context) {
- CollectDiagnosticsToStringScope scope(context);
-
- MlirType type = mlirLLVMStructTypeLiteralGet(
- context, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "context"_a = nb::none());
-
- llvmStructType.def_classmethod(
- "get_identified",
- [](const nb::object &cls, const std::string &name, MlirContext context) {
- return cls(mlirLLVMStructTypeIdentifiedGet(
- context, mlirStringRefCreate(name.data(), name.size())));
- },
- "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm {
+//===--------------------------------------------------------------------===//
+// StructType
+//===--------------------------------------------------------------------===//
+
+struct StructType : PyConcreteType<StructType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirLLVMStructTypeGetTypeID;
+ static constexpr const char *pyClassName = "StructType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_literal",
+ [](const std::vector<MlirType> &elements, bool packed, MlirLocation loc,
+ DefaultingPyMlirContext context) {
+ python::CollectDiagnosticsToStringScope scope(
+ mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return StructType(context->getRef(), type);
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none(),
+ "context"_a = nb::none());
+
+ c.def_static(
+ "get_literal_unchecked",
+ [](const std::vector<MlirType> &elements, bool packed,
+ DefaultingPyMlirContext context) {
+ python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+ MlirType type = mlirLLVMStructTypeLiteralGet(
+ context.get()->get(), elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return StructType(context->getRef(), type);
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
+
+ c.def_static(
+ "get_identified",
+ [](const std::string &name, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeIdentifiedGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.size())));
+ },
+ "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+ c.def_static(
+ "get_opaque",
+ [](const std::string &name, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeOpaqueGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.size())));
+ },
+ "name"_a, "context"_a = nb::none());
+
+ c.def(
+ "set_body",
+ [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+ MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+ self, elements.size(), elements.data(), packed);
+ if (!mlirLogicalResultIsSuccess(result)) {
+ throw nb::value_error(
+ "Struct body already set to different content.");
+ }
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false);
+
+ c.def_static(
+ "new_identified",
+ [](const std::string &name, const std::vector<MlirType> &elements,
+ bool packed, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeIdentifiedNewGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.length()),
+ elements.size(), elements.data(), packed));
+ },
+ "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
+
+ c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
+ if (mlirLLVMStructTypeIsLiteral(type))
+ return std::nullopt;
+
+ MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+ return StringRef(stringRef.data, stringRef.length).str();
+ });
+
+ c.def_prop_ro("body", [](PyType type) -> nb::object {
+ // Don't crash in absence of a body.
+ if (mlirLLVMStructTypeIsOpaque(type))
+ return nb::none();
+
+ nb::list body;
+ for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+ i < e; ++i) {
+ body.append(mlirLLVMStructTypeGetElementType(type, i));
+ }
+ return body;
+ });
+
+ c.def_prop_ro("packed",
+ [](PyType type) { return mlirLLVMStructTypeIsPacked(type); })...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/174156
More information about the llvm-branch-commits mailing list