[llvm-branch-commits] [mlir] [mlir][Python] port in-tree dialect extensions to use MLIRPythonSupport (PR #174156)

Maksim Levental via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jan 3 20:04:09 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/174156

>From bcb510a239ae586ee8c16a69d81283116192feb1 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 14:20:39 -0800
Subject: [PATCH 1/3] [mlir][Python] move IRTypes and IRAttributes to public
 headers

---
Reviewer note: despite the subject line, the only change in this patch drops
the trailing newline of mlir/test/python/lib/PythonTestModuleNanobind.cpp
(note the "\ No newline at end of file" marker); presumably unintended -- TODO
confirm and drop or restore the newline before landing.

 mlir/test/python/lib/PythonTestModuleNanobind.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a296b5e814b4b..b229c02ccf5e6 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -156,4 +156,4 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
   PyTestType::bind(m);
   PyTestIntegerRankedTensorType::bind(m);
   PyTestTensorValue::bind(m);
-}
+}
\ No newline at end of file

>From b6af0195e1ab989efcd0c81b3a6ea21fb61eeccd Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 29 Dec 2025 11:14:00 -0800
Subject: [PATCH 2/3] [mlir][Python] port dialect extensions to use core
 PyConcreteType, PyConcreteAttribute

---
 mlir/lib/Bindings/Python/DialectAMDGPU.cpp    | 111 ++-
 mlir/lib/Bindings/Python/DialectGPU.cpp       | 152 ++--
 mlir/lib/Bindings/Python/DialectLLVM.cpp      | 297 ++++---
 mlir/lib/Bindings/Python/DialectNVGPU.cpp     |  50 +-
 mlir/lib/Bindings/Python/DialectPDL.cpp       | 228 +++--
 mlir/lib/Bindings/Python/DialectQuant.cpp     | 810 ++++++++++--------
 mlir/lib/Bindings/Python/DialectSMT.cpp       |  89 +-
 .../Bindings/Python/DialectSparseTensor.cpp   | 266 +++---
 mlir/lib/Bindings/Python/DialectTransform.cpp | 249 +++---
 .../dialects/transform/extras/__init__.py     |  11 +-
 mlir/test/python/dialects/pdl_types.py        | 211 ++---
 11 files changed, 1439 insertions(+), 1035 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
index 26ffc0e427e41..de24dfa9660c1 100644
--- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -8,58 +8,97 @@
 
 #include "mlir-c/Dialect/AMDGPU.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
-  auto amdgpuTDMBaseType =
-      mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
-                         mlirAMDGPUTDMBaseTypeGetTypeID);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace amdgpu { // Python wrappers for AMDGPU dialect types.
+struct TDMBaseType : PyConcreteType<TDMBaseType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMBaseType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMBaseType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
-        return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
-      },
-      "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
-      nb::arg("element_type"), nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &elementType, DefaultingPyMlirContext context) {
+          return TDMBaseType(
+              context->getRef(),
+              mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType));
+        },
+        "Gets an instance of TDMBaseType in the same context",
+        nb::arg("element_type"), nb::arg("context").none() = nb::none());
+  }
+};
 
-  auto amdgpuTDMDescriptorType = mlir_type_subclass(
-      m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
-      mlirAMDGPUTDMDescriptorTypeGetTypeID);
+struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMDescriptorType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMDescriptorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMDescriptorType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMDescriptorType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
-      },
-      "Gets an instance of TDMDescriptorType in the same context",
-      nb::arg("cls"), nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return TDMDescriptorType(
+              context->getRef(),
+              mlirAMDGPUTDMDescriptorTypeGet(context.get()->get()));
+        },
+        "Gets an instance of TDMDescriptorType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
 
-  auto amdgpuTDMGatherBaseType = mlir_type_subclass(
-      m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
-      mlirAMDGPUTDMGatherBaseTypeGetTypeID);
+struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMGatherBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMGatherBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMGatherBaseType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMGatherBaseType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType elementType, MlirType indexType,
-         MlirContext ctx) {
-        return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
-      },
-      "Gets an instance of TDMGatherBaseType in the same context",
-      nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
-      nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &elementType, const PyType &indexType,
+           DefaultingPyMlirContext context) {
+          return TDMGatherBaseType(
+              context->getRef(),
+              mlirAMDGPUTDMGatherBaseTypeGet(context.get()->get(), elementType,
+                                             indexType));
+        },
+        "Gets an instance of TDMGatherBaseType in the same context",
+        nb::arg("element_type"), nb::arg("index_type"),
+        nb::arg("context").none() = nb::none());
+  }
 };
 
+static void populateDialectAMDGPUSubmodule(nb::module_ &m) {
+  TDMBaseType::bind(m);
+  TDMDescriptorType::bind(m);
+  TDMGatherBaseType::bind(m);
+}
+} // namespace amdgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
 NB_MODULE(_mlirDialectsAMDGPU, m) {
   m.doc() = "MLIR AMDGPU dialect.";
 
-  populateDialectAMDGPUSubmodule(m);
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::amdgpu::
+      populateDialectAMDGPUSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 2568d535edb5a..3ea8edec7b136 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,83 +9,105 @@
 #include "mlir-c/Dialect/GPU.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace nanobind::literals;
-
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace gpu {
 // -----------------------------------------------------------------------------
-// Module initialization.
+// AsyncTokenType: Python class exposing a static get(context=None).
 // -----------------------------------------------------------------------------
 
-NB_MODULE(_mlirDialectsGPU, m) {
-  m.doc() = "MLIR GPU Dialect";
-  //===-------------------------------------------------------------------===//
-  // AsyncTokenType
-  //===-------------------------------------------------------------------===//
+struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
+  static constexpr const char *pyClassName = "AsyncTokenType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return AsyncTokenType(context->getRef(),
+                                mlirGPUAsyncTokenTypeGet(context.get()->get()));
+        },
+        "Gets an instance of AsyncTokenType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// ObjectAttr (gpu.object); always created in the context of `target`.
+//===-------------------------------------------------------------------===//
+
+struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr;
+  static constexpr const char *pyClassName = "ObjectAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
 
-  auto mlirGPUAsyncTokenType =
-      mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](MlirAttribute target, uint32_t format, const nb::bytes &object,
+           std::optional<MlirAttribute> mlirObjectProps,
+           std::optional<MlirAttribute> mlirKernelsAttr) {
+          MlirContext ctx = mlirAttributeGetContext(target);
+          MlirStringRef objectStrRef = mlirStringRefCreate(
+              static_cast<char *>(const_cast<void *>(object.data())),
+              object.size());
+          return ObjectAttr(
+              PyMlirContext::forContext(ctx),
+              mlirGPUObjectAttrGetWithKernels(
+                  ctx, target, format, objectStrRef,
+                  mlirObjectProps.has_value() ? *mlirObjectProps
+                                              : MlirAttribute{nullptr},
+                  mlirKernelsAttr.has_value() ? *mlirKernelsAttr
+                                              : MlirAttribute{nullptr}));
+        },
+        "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(),
+        "kernels"_a = nb::none(),
+        "Gets a gpu.object from parameters.");
 
-  mlirGPUAsyncTokenType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirGPUAsyncTokenTypeGet(ctx));
-      },
-      "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
-      nb::arg("ctx") = nb::none());
+    c.def_prop_ro("target", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetTarget(self);
+    });
+    c.def_prop_ro("format", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetFormat(self);
+    });
+    c.def_prop_ro("object", [](MlirAttribute self) {
+      MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+      return nb::bytes(stringRef.data, stringRef.length);
+    });
+    c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasProperties(self))
+        return nb::cast(mlirGPUObjectAttrGetProperties(self));
+      return nb::none();
+    });
+    c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasKernels(self))
+        return nb::cast(mlirGPUObjectAttrGetKernels(self));
+      return nb::none();
+    });
+  }
+};
+} // namespace gpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
-  //===-------------------------------------------------------------------===//
-  // ObjectAttr
-  //===-------------------------------------------------------------------===//
+// -----------------------------------------------------------------------------
+// Module initialization: registers AsyncTokenType and ObjectAttr.
+// -----------------------------------------------------------------------------
+
+NB_MODULE(_mlirDialectsGPU, m) {
+  m.doc() = "MLIR GPU Dialect";
 
-  mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
-      .def_classmethod(
-          "get",
-          [](const nb::object &cls, MlirAttribute target, uint32_t format,
-             const nb::bytes &object,
-             std::optional<MlirAttribute> mlirObjectProps,
-             std::optional<MlirAttribute> mlirKernelsAttr) {
-            MlirStringRef objectStrRef = mlirStringRefCreate(
-                static_cast<char *>(const_cast<void *>(object.data())),
-                object.size());
-            return cls(mlirGPUObjectAttrGetWithKernels(
-                mlirAttributeGetContext(target), target, format, objectStrRef,
-                mlirObjectProps.has_value() ? *mlirObjectProps
-                                            : MlirAttribute{nullptr},
-                mlirKernelsAttr.has_value() ? *mlirKernelsAttr
-                                            : MlirAttribute{nullptr}));
-          },
-          "cls"_a, "target"_a, "format"_a, "object"_a,
-          "properties"_a = nb::none(), "kernels"_a = nb::none(),
-          "Gets a gpu.object from parameters.")
-      .def_property_readonly(
-          "target",
-          [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
-      .def_property_readonly(
-          "format",
-          [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
-      .def_property_readonly(
-          "object",
-          [](MlirAttribute self) {
-            MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
-            return nb::bytes(stringRef.data, stringRef.length);
-          })
-      .def_property_readonly("properties",
-                             [](MlirAttribute self) -> nb::object {
-                               if (mlirGPUObjectAttrHasProperties(self))
-                                 return nb::cast(
-                                     mlirGPUObjectAttrGetProperties(self));
-                               return nb::none();
-                             })
-      .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
-        if (mlirGPUObjectAttrHasKernels(self))
-          return nb::cast(mlirGPUObjectAttrGetKernels(self));
-        return nb::none();
-      });
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::AsyncTokenType::bind(m);
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::ObjectAttr::bind(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 05681cecf82b3..d4eb078c0f55c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -13,149 +13,176 @@
 #include "mlir-c/Support.h"
 #include "mlir-c/Target/LLVMIR.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 
 using namespace nanobind::literals;
-
 using namespace llvm;
 using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
-  //===--------------------------------------------------------------------===//
-  // StructType
-  //===--------------------------------------------------------------------===//
-
-  auto llvmStructType = mlir_type_subclass(
-      m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
-  llvmStructType
-      .def_classmethod(
-          "get_literal",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirLocation loc) {
-            CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
-            MlirType type = mlirLLVMStructTypeLiteralGetChecked(
-                loc, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "loc"_a = nb::none())
-      .def_classmethod(
-          "get_literal_unchecked",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirContext context) {
-            CollectDiagnosticsToStringScope scope(context);
-
-            MlirType type = mlirLLVMStructTypeLiteralGet(
-                context, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "context"_a = nb::none());
-
-  llvmStructType.def_classmethod(
-      "get_identified",
-      [](const nb::object &cls, const std::string &name, MlirContext context) {
-        return cls(mlirLLVMStructTypeIdentifiedGet(
-            context, mlirStringRefCreate(name.data(), name.size())));
-      },
-      "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm { // NOTE(review): shadows ::llvm inside these namespaces.
+//===--------------------------------------------------------------------===//
+// StructType: literal, identified and opaque LLVM dialect struct types.
+//===--------------------------------------------------------------------===//
+
+struct StructType : PyConcreteType<StructType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMStructTypeGetTypeID;
+  static constexpr const char *pyClassName = "StructType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_literal",
+        [](const std::vector<MlirType> &elements, bool packed,
+           MlirLocation loc) {
+          // The type lives in loc's context, not necessarily Context.current.
+          MlirContext ctx = mlirLocationGetContext(loc);
+          python::CollectDiagnosticsToStringScope scope(ctx);
+          MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+              loc, elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(PyMlirContext::forContext(ctx), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false,
+        "loc"_a = nb::none());
+
+    c.def_static(
+        "get_literal_unchecked",
+        [](const std::vector<MlirType> &elements, bool packed,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+          MlirType type = mlirLLVMStructTypeLiteralGet(
+              context.get()->get(), elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_identified",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+    c.def_static(
+        "get_opaque",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeOpaqueGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, "context"_a = nb::none());
+
+    c.def(
+        "set_body",
+        [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+          MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+              self, elements.size(), elements.data(), packed);
+          if (!mlirLogicalResultIsSuccess(result)) {
+            throw nb::value_error(
+                "Struct body already set to different content.");
+          }
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false);
+
+    c.def_static(
+        "new_identified",
+        [](const std::string &name, const std::vector<MlirType> &elements,
+           bool packed, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedNewGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.length()),
+                                elements.size(), elements.data(), packed));
+        },
+        "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
+      if (mlirLLVMStructTypeIsLiteral(type))
+        return std::nullopt;
+
+      MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+      return StringRef(stringRef.data, stringRef.length).str();
+    });
+
+    c.def_prop_ro("body", [](PyType type) -> nb::object {
+      // Don't crash in absence of a body.
+      if (mlirLLVMStructTypeIsOpaque(type))
+        return nb::none();
+
+      nb::list body;
+      for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+           i < e; ++i) {
+        body.append(mlirLLVMStructTypeGetElementType(type, i));
+      }
+      return body;
+    });
+
+    c.def_prop_ro("packed",
+                  [](PyType type) { return mlirLLVMStructTypeIsPacked(type); });
+
+    c.def_prop_ro("opaque",
+                  [](PyType type) { return mlirLLVMStructTypeIsOpaque(type); });
+  }
+};
+
+//===--------------------------------------------------------------------===//
+// PointerType: `get` uses address space 0 when address_space is None.
+//===--------------------------------------------------------------------===//
+
+struct PointerType : PyConcreteType<PointerType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMPointerType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMPointerTypeGetTypeID;
+  static constexpr const char *pyClassName = "PointerType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::optional<unsigned> addressSpace,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(context.get()->get());
+          MlirType type = mlirLLVMPointerTypeGet(
+              context.get()->get(),
+              addressSpace.has_value() ? *addressSpace : 0);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return PointerType(context->getRef(), type);
+        },
+        "address_space"_a = nb::none(), nb::kw_only(),
+        "context"_a = nb::none());
+    c.def_prop_ro("address_space", [](PyType type) {
+      return mlirLLVMPointerTypeGetAddressSpace(type);
+    });
+  }
+};
 
-  llvmStructType.def_classmethod(
-      "get_opaque",
-      [](const nb::object &cls, const std::string &name, MlirContext context) {
-        return cls(mlirLLVMStructTypeOpaqueGet(
-            context, mlirStringRefCreate(name.data(), name.size())));
-      },
-      "cls"_a, "name"_a, "context"_a = nb::none());
-
-  llvmStructType.def(
-      "set_body",
-      [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
-        MlirLogicalResult result = mlirLLVMStructTypeSetBody(
-            self, elements.size(), elements.data(), packed);
-        if (!mlirLogicalResultIsSuccess(result)) {
-          throw nb::value_error(
-              "Struct body already set to different content.");
-        }
-      },
-      "elements"_a, nb::kw_only(), "packed"_a = false);
-
-  llvmStructType.def_classmethod(
-      "new_identified",
-      [](const nb::object &cls, const std::string &name,
-         const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
-        return cls(mlirLLVMStructTypeIdentifiedNewGet(
-            ctx, mlirStringRefCreate(name.data(), name.length()),
-            elements.size(), elements.data(), packed));
-      },
-      "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-      "context"_a = nb::none());
-
-  llvmStructType.def_property_readonly(
-      "name", [](MlirType type) -> std::optional<std::string> {
-        if (mlirLLVMStructTypeIsLiteral(type))
-          return std::nullopt;
-
-        MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
-        return StringRef(stringRef.data, stringRef.length).str();
-      });
-
-  llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
-    // Don't crash in absence of a body.
-    if (mlirLLVMStructTypeIsOpaque(type))
-      return nb::none();
-
-    nb::list body;
-    for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
-         ++i) {
-      body.append(mlirLLVMStructTypeGetElementType(type, i));
-    }
-    return body;
-  });
-
-  llvmStructType.def_property_readonly(
-      "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
-
-  llvmStructType.def_property_readonly(
-      "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
-
-  //===--------------------------------------------------------------------===//
-  // PointerType
-  //===--------------------------------------------------------------------===//
-
-  mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType,
-                     mlirLLVMPointerTypeGetTypeID)
-      .def_classmethod(
-          "get",
-          [](const nb::object &cls, std::optional<unsigned> addressSpace,
-             MlirContext context) {
-            CollectDiagnosticsToStringScope scope(context);
-            MlirType type = mlirLLVMPointerTypeGet(
-                context, addressSpace.has_value() ? *addressSpace : 0);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "address_space"_a = nb::none(), nb::kw_only(),
-          "context"_a = nb::none())
-      .def_property_readonly("address_space", [](MlirType type) {
-        return mlirLLVMPointerTypeGetAddressSpace(type);
-      });
+static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
+  StructType::bind(m);
+  PointerType::bind(m);
 
   m.def(
       "translate_module_to_llvmir",
@@ -167,9 +194,14 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
       // clang-format on
       "module"_a, nb::rv_policy::take_ownership);
 }
+} // namespace llvm
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsLLVM, m) {
   m.doc() = "MLIR LLVM Dialect";
 
-  populateDialectLLVMSubmodule(m);
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::llvm::
+      populateDialectLLVMSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index 18917416412c1..179cc32520e83 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -8,34 +8,48 @@
 
 #include "mlir-c/Dialect/NVGPU.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
-  auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
-      m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace nvgpu { // Python wrappers for NVGPU dialect types.
+struct TensorMapDescriptorType : PyConcreteType<TensorMapDescriptorType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsANVGPUTensorMapDescriptorType;
+  static constexpr const char *pyClassName = "TensorMapDescriptorType";
+  using PyConcreteType::PyConcreteType;
 
-  nvgpuTensorMapDescriptorType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType tensorMemrefType, int swizzle,
-         int l2promo, int oobFill, int interleave, MlirContext ctx) {
-        return cls(mlirNVGPUTensorMapDescriptorTypeGet(
-            ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
-      },
-      "Gets an instance of TensorMapDescriptorType in the same context",
-      nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"),
-      nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"),
-      nb::arg("ctx") = nb::none());
-}
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &tensorMemrefType, int swizzle, int l2promo,
+           int oobFill, int interleave, DefaultingPyMlirContext context) {
+          return TensorMapDescriptorType(
+              context->getRef(), mlirNVGPUTensorMapDescriptorTypeGet(
+                                     context.get()->get(), tensorMemrefType,
+                                     swizzle, l2promo, oobFill, interleave));
+        },
+        "Gets an instance of TensorMapDescriptorType in the same context",
+        nb::arg("tensor_type"), nb::arg("swizzle"), nb::arg("l2promo"),
+        nb::arg("oob_fill"), nb::arg("interleave"),
+        nb::arg("context").none() = nb::none());
+  }
+};
+} // namespace nvgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsNVGPU, m) {
   m.doc() = "MLIR NVGPU dialect.";
 
-  populateDialectNVGPUSubmodule(m);
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::nvgpu::TensorMapDescriptorType::
+      bind(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index 1acb41080f711..d2ed3b141d724 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -8,98 +8,160 @@
 
 #include "mlir-c/Dialect/PDL.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectPDLSubmodule(const nanobind::module_ &m) {
-  //===-------------------------------------------------------------------===//
-  // PDLType
-  //===-------------------------------------------------------------------===//
-
-  auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType);
-
-  //===-------------------------------------------------------------------===//
-  // AttributeType
-  //===-------------------------------------------------------------------===//
-
-  auto attributeType =
-      mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
-  attributeType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirPDLAttributeTypeGet(ctx));
-      },
-      "Get an instance of AttributeType in given context.", nb::arg("cls"),
-      nb::arg("context") = nb::none());
-
-  //===-------------------------------------------------------------------===//
-  // OperationType
-  //===-------------------------------------------------------------------===//
-
-  auto operationType =
-      mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
-  operationType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirPDLOperationTypeGet(ctx));
-      },
-      "Get an instance of OperationType in given context.", nb::arg("cls"),
-      nb::arg("context") = nb::none());
-
-  //===-------------------------------------------------------------------===//
-  // RangeType
-  //===-------------------------------------------------------------------===//
-
-  auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
-  rangeType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType elementType) {
-        return cls(mlirPDLRangeTypeGet(elementType));
-      },
-      "Gets an instance of RangeType in the same context as the provided "
-      "element type.",
-      nb::arg("cls"), nb::arg("element_type"));
-  rangeType.def_property_readonly(
-      "element_type",
-      [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
-      nb::sig(
-          "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
-      "Get the element type.");
-
-  //===-------------------------------------------------------------------===//
-  // TypeType
-  //===-------------------------------------------------------------------===//
-
-  auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
-  typeType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirPDLTypeTypeGet(ctx));
-      },
-      "Get an instance of TypeType in given context.", nb::arg("cls"),
-      nb::arg("context") = nb::none());
-
-  //===-------------------------------------------------------------------===//
-  // ValueType
-  //===-------------------------------------------------------------------===//
-
-  auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
-  valueType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirPDLValueTypeGet(ctx));
-      },
-      "Get an instance of TypeType in given context.", nb::arg("cls"),
-      nb::arg("context") = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace pdl {
+
+//===-------------------------------------------------------------------===//
+// PDLType: recognized via mlirTypeIsAPDLType; adds no Python members.
+//===-------------------------------------------------------------------===//
+
+struct PDLType : PyConcreteType<PDLType> {
+  using PyConcreteType::PyConcreteType;
+  static constexpr const char *pyClassName = "PDLType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLType;
+
+  static void bindDerived(ClassTy & /*c*/) {}
+};
+
+//===-------------------------------------------------------------------===//
+// AttributeType: PDL type created from a context alone.
+//===-------------------------------------------------------------------===//
+
+struct AttributeType : PyConcreteType<AttributeType> {
+  using PyConcreteType::PyConcreteType;
+  static constexpr const char *pyClassName = "AttributeType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType;
+
+  static void bindDerived(ClassTy &c) {
+    // Python: AttributeType.get(context=None).
+    auto factory = [](DefaultingPyMlirContext context) {
+      MlirType built = mlirPDLAttributeTypeGet(context.get()->get());
+      return AttributeType(context->getRef(), built);
+    };
+    c.def_static("get", factory,
+                 "Get an instance of AttributeType in given context.",
+                 nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// OperationType: PDL type created from a context alone.
+//===-------------------------------------------------------------------===//
+
+struct OperationType : PyConcreteType<OperationType> {
+  using PyConcreteType::PyConcreteType;
+  static constexpr const char *pyClassName = "OperationType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType;
+
+  static void bindDerived(ClassTy &c) {
+    // Python: OperationType.get(context=None).
+    auto factory = [](DefaultingPyMlirContext context) {
+      MlirType built = mlirPDLOperationTypeGet(context.get()->get());
+      return OperationType(context->getRef(), built);
+    };
+    c.def_static("get", factory,
+                 "Get an instance of OperationType in given context.",
+                 nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// RangeType
+//===-------------------------------------------------------------------===//
+
+struct RangeType : PyConcreteType<RangeType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType;
+  static constexpr const char *pyClassName = "RangeType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // mlirPDLRangeTypeGet takes no context argument, so the result lives in
+    // the element type's context; wrap it with that context's reference.
+    c.def_static(
+        "get", [](PyType &elt, DefaultingPyMlirContext /*context*/) {
+          return RangeType(elt.getContext(), mlirPDLRangeTypeGet(elt));
+        },
+        "Gets an instance of RangeType in the same context as the provided "
+        "element type.",
+        nb::arg("element_type"), nb::arg("context").none() = nb::none());
+    c.def_prop_ro(
+        "element_type", [](PyType &type) {
+          return PyType(type.getContext(),
+                        mlirPDLRangeTypeGetElementType(type));
+        },
+        nb::sig(
+            "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
+        "Get the element type.");
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// TypeType: PDL type created from a context alone.
+//===-------------------------------------------------------------------===//
+
+struct TypeType : PyConcreteType<TypeType> {
+  using PyConcreteType::PyConcreteType;
+  static constexpr const char *pyClassName = "TypeType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType;
+
+  static void bindDerived(ClassTy &c) {
+    // Python: TypeType.get(context=None).
+    auto factory = [](DefaultingPyMlirContext context) {
+      MlirType built = mlirPDLTypeTypeGet(context.get()->get());
+      return TypeType(context->getRef(), built);
+    };
+    c.def_static("get", factory,
+                 "Get an instance of TypeType in given context.",
+                 nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// ValueType: PDL value type, created from a context alone.
+//===-------------------------------------------------------------------===//
+
+struct ValueType : PyConcreteType<ValueType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType;
+  static constexpr const char *pyClassName = "ValueType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return ValueType(context->getRef(),
+                           mlirPDLValueTypeGet(context.get()->get()));
+        },
+        "Get an instance of ValueType in given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+static void populateDialectPDLSubmodule(nb::module_ &mod) {
+  PDLType::bind(mod);
+  AttributeType::bind(mod);
+  OperationType::bind(mod);
+  RangeType::bind(mod);
+  TypeType::bind(mod);
+  ValueType::bind(mod);
 }
+} // namespace pdl
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsPDL, m) {
   m.doc() = "MLIR PDL dialect.";
-  populateDialectPDLSubmodule(m);
+  namespace pdl_py = mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::pdl;
+  pdl_py::populateDialectPDLSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index a5220fcc00604..a1e0a281a708d 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -6,385 +6,485 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <cstdint>
 #include <vector>
 
-#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Dialect/Quant.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace llvm;
-using namespace mlir;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectQuantSubmodule(const nb::module_ &m) {
-  //===-------------------------------------------------------------------===//
-  // QuantizedType
-  //===-------------------------------------------------------------------===//
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace quant {
+//===-------------------------------------------------------------------===//
+// QuantizedType
+//===-------------------------------------------------------------------===//
 
-  auto quantizedType =
-      mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
-  quantizedType.def_staticmethod(
-      "default_minimum_for_integer",
-      [](bool isSigned, unsigned integralWidth) {
-        return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
-                                                            integralWidth);
-      },
-      "Default minimum value for the integer with the specified signedness and "
-      "bit width.",
-      nb::arg("is_signed"), nb::arg("integral_width"));
-  quantizedType.def_staticmethod(
-      "default_maximum_for_integer",
-      [](bool isSigned, unsigned integralWidth) {
-        return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
-                                                            integralWidth);
-      },
-      "Default maximum value for the integer with the specified signedness and "
-      "bit width.",
-      nb::arg("is_signed"), nb::arg("integral_width"));
-  quantizedType.def_property_readonly(
-      "expressed_type",
-      [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
-      "Type expressed by this quantized type.");
-  quantizedType.def_property_readonly(
-      "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
-      "Flags of this quantized type (named accessors should be preferred to "
-      "this)");
-  quantizedType.def_property_readonly(
-      "is_signed",
-      [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
-      "Signedness of this quantized type.");
-  quantizedType.def_property_readonly(
-      "storage_type",
-      [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
-      "Storage type backing this quantized type.");
-  quantizedType.def_property_readonly(
-      "storage_type_min",
-      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
-      "The minimum value held by the storage type of this quantized type.");
-  quantizedType.def_property_readonly(
-      "storage_type_max",
-      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
-      "The maximum value held by the storage type of this quantized type.");
-  quantizedType.def_property_readonly(
-      "storage_type_integral_width",
-      [](MlirType type) {
-        return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
-      },
-      "The bitwidth of the storage type of this quantized type.");
-  quantizedType.def(
-      "is_compatible_expressed_type",
-      [](MlirType type, MlirType candidate) {
-        return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
-      },
-      "Checks whether the candidate type can be expressed by this quantized "
-      "type.",
-      nb::arg("candidate"));
-  quantizedType.def_property_readonly(
-      "quantized_element_type",
-      [](MlirType type) {
-        return mlirQuantizedTypeGetQuantizedElementType(type);
-      },
-      "Element type of this quantized type expressed as quantized type.");
-  quantizedType.def(
-      "cast_from_storage_type",
-      [](MlirType type, MlirType candidate) {
-        MlirType castResult =
-            mlirQuantizedTypeCastFromStorageType(type, candidate);
-        if (!mlirTypeIsNull(castResult))
-          return castResult;
-        throw nb::type_error("Invalid cast.");
-      },
-      "Casts from a type based on the storage type of this quantized type to a "
-      "corresponding type based on the quantized type. Raises TypeError if the "
-      "cast is not valid.",
-      nb::arg("candidate"));
-  quantizedType.def_staticmethod(
-      "cast_to_storage_type",
-      [](MlirType type) {
-        MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
-        if (!mlirTypeIsNull(castResult))
-          return castResult;
-        throw nb::type_error("Invalid cast.");
-      },
-      "Casts from a type based on a quantized type to a corresponding type "
-      "based on the storage type of this quantized type. Raises TypeError if "
-      "the cast is not valid.",
-      nb::arg("type"));
-  quantizedType.def(
-      "cast_from_expressed_type",
-      [](MlirType type, MlirType candidate) {
-        MlirType castResult =
-            mlirQuantizedTypeCastFromExpressedType(type, candidate);
-        if (!mlirTypeIsNull(castResult))
-          return castResult;
-        throw nb::type_error("Invalid cast.");
-      },
-      "Casts from a type based on the expressed type of this quantized type to "
-      "a corresponding type based on the quantized type. Raises TypeError if "
-      "the cast is not valid.",
-      nb::arg("candidate"));
-  quantizedType.def_staticmethod(
-      "cast_to_expressed_type",
-      [](MlirType type) {
-        MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
-        if (!mlirTypeIsNull(castResult))
-          return castResult;
-        throw nb::type_error("Invalid cast.");
-      },
-      "Casts from a type based on a quantized type to a corresponding type "
-      "based on the expressed type of this quantized type. Raises TypeError if "
-      "the cast is not valid.",
-      nb::arg("type"));
-  quantizedType.def(
-      "cast_expressed_to_storage_type",
-      [](MlirType type, MlirType candidate) {
-        MlirType castResult =
-            mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
-        if (!mlirTypeIsNull(castResult))
-          return castResult;
-        throw nb::type_error("Invalid cast.");
-      },
-      "Casts from a type based on the expressed type of this quantized type to "
-      "a corresponding type based on the storage type. Raises TypeError if the "
-      "cast is not valid.",
-      nb::arg("candidate"));
+// Python class for any quant dialect type (see mlirTypeIsAQuantizedType);
+// AnyQuantizedType and its siblings use it as their PyConcreteType base.
+struct QuantizedType : PyConcreteType<QuantizedType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAQuantizedType;
+  static constexpr const char *pyClassName = "QuantizedType";
+  using PyConcreteType::PyConcreteType;
 
-  quantizedType.get_class().attr("FLAG_SIGNED") =
-      mlirQuantizedTypeGetSignedFlag();
+  static void bindDerived(ClassTy &c) {
+    // Signed-flag bit, e.g. for the `flags` argument of subclass get()s.
+    c.attr("FLAG_SIGNED") = mlirQuantizedTypeGetSignedFlag();
+    c.def_static(
+        "default_minimum_for_integer",
+        [](bool isSigned, unsigned integralWidth) {
+          return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
+                                                              integralWidth);
+        },
+        "Default minimum value for the integer with the specified signedness "
+        "and bit width.",
+        nb::arg("is_signed"), nb::arg("integral_width"));
+    c.def_static(
+        "default_maximum_for_integer",
+        [](bool isSigned, unsigned integralWidth) {
+          return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
+                                                              integralWidth);
+        },
+        "Default maximum value for the integer with the specified signedness "
+        "and bit width.",
+        nb::arg("is_signed"), nb::arg("integral_width"));
+    c.def_prop_ro(
+        "expressed_type",
+        [](PyType &type) {
+          return PyType(type.getContext(),
+                        mlirQuantizedTypeGetExpressedType(type));
+        },
+        "Type expressed by this quantized type.");
+    c.def_prop_ro(
+        "flags",
+        [](const PyType &type) { return mlirQuantizedTypeGetFlags(type); },
+        "Flags of this quantized type (named accessors should be preferred to "
+        "this)");
+    c.def_prop_ro(
+        "is_signed",
+        [](const PyType &type) { return mlirQuantizedTypeIsSigned(type); },
+        "Signedness of this quantized type.");
+    c.def_prop_ro(
+        "storage_type",
+        [](PyType &type) {
+          return PyType(type.getContext(),
+                        mlirQuantizedTypeGetStorageType(type));
+        },
+        "Storage type backing this quantized type.");
+    c.def_prop_ro(
+        "storage_type_min",
+        [](const PyType &type) {
+          return mlirQuantizedTypeGetStorageTypeMin(type);
+        },
+        "The minimum value held by the storage type of this quantized type.");
+    c.def_prop_ro(
+        "storage_type_max",
+        [](const PyType &type) {
+          return mlirQuantizedTypeGetStorageTypeMax(type);
+        },
+        "The maximum value held by the storage type of this quantized type.");
+    c.def_prop_ro(
+        "storage_type_integral_width",
+        [](const PyType &type) {
+          return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
+        },
+        "The bitwidth of the storage type of this quantized type.");
+    c.def(
+        "is_compatible_expressed_type",
+        [](const PyType &type, const PyType &candidate) {
+          return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
+        },
+        "Checks whether the candidate type can be expressed by this quantized "
+        "type.",
+        nb::arg("candidate"));
+    c.def_prop_ro(
+        "quantized_element_type",
+        [](PyType &type) {
+          return PyType(type.getContext(),
+                        mlirQuantizedTypeGetQuantizedElementType(type));
+        },
+        "Element type of this quantized type expressed as quantized type.");
+    c.def(
+        "cast_from_storage_type",
+        [](PyType &type, const PyType &candidate) {
+          MlirType castResult =
+              mlirQuantizedTypeCastFromStorageType(type, candidate);
+          // Result is a type *based on* this quantized type, not necessarily a
+          // QuantizedType itself, so wrap it as a plain PyType.
+          if (!mlirTypeIsNull(castResult))
+            return PyType(type.getContext(), castResult);
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on the storage type of this quantized type to "
+        "a corresponding type based on the quantized type. Raises TypeError if "
+        "the cast is not valid.",
+        nb::arg("candidate"));
+    // Static helper: returns the raw MlirType rather than a PyType.
+    c.def_static(
+        "cast_to_storage_type",
+        [](const PyType &type) {
+          MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
+          if (!mlirTypeIsNull(castResult))
+            return castResult;
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on a quantized type to a corresponding type "
+        "based on the storage type of this quantized type. Raises TypeError if "
+        "the cast is not valid.",
+        nb::arg("type"));
+    c.def(
+        "cast_from_expressed_type",
+        [](PyType &type, const PyType &candidate) {
+          MlirType castResult =
+              mlirQuantizedTypeCastFromExpressedType(type, candidate);
+          if (!mlirTypeIsNull(castResult))
+            return PyType(type.getContext(), castResult);
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on the expressed type of this quantized type "
+        "to a corresponding type based on the quantized type. Raises TypeError "
+        "if the cast is not valid.",
+        nb::arg("candidate"));
+    // Static helper: returns the raw MlirType rather than a PyType.
+    c.def_static(
+        "cast_to_expressed_type",
+        [](const PyType &type) {
+          MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
+          if (!mlirTypeIsNull(castResult))
+            return castResult;
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on a quantized type to a corresponding type "
+        "based on the expressed type of this quantized type. Raises TypeError "
+        "if the cast is not valid.",
+        nb::arg("type"));
+    c.def(
+        "cast_expressed_to_storage_type",
+        [](PyType &type, const PyType &candidate) {
+          MlirType castResult =
+              mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
+          if (!mlirTypeIsNull(castResult))
+            return PyType(type.getContext(), castResult);
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on the expressed type of this quantized type "
+        "to a corresponding type based on the storage type. Raises TypeError "
+        "if the cast is not valid.",
+        nb::arg("candidate"));
+  }
+};
 
-  //===-------------------------------------------------------------------===//
-  // AnyQuantizedType
-  //===-------------------------------------------------------------------===//
+//===-------------------------------------------------------------------===//
+// AnyQuantizedType
+//===-------------------------------------------------------------------===//
 
-  auto anyQuantizedType =
-      mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
-                         quantizedType.get_class());
-  anyQuantizedType.def_classmethod(
-      "get",
-      [](const nb::object &cls, unsigned flags, MlirType storageType,
-         MlirType expressedType, int64_t storageTypeMin,
-         int64_t storageTypeMax) {
-        return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
-                                           storageTypeMin, storageTypeMax));
-      },
-      "Gets an instance of AnyQuantizedType in the same context as the "
-      "provided storage type.",
-      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
-      nb::arg("expressed_type"), nb::arg("storage_type_min"),
-      nb::arg("storage_type_max"));
+struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType;
+  static constexpr const char *pyClassName = "AnyQuantizedType";
+  using PyConcreteType::PyConcreteType;
 
-  //===-------------------------------------------------------------------===//
-  // UniformQuantizedType
-  //===-------------------------------------------------------------------===//
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get", // C builder takes no context: wrap with storage_type's.
+        [](unsigned flags, PyType &storageType, const PyType &expressedType,
+           int64_t storageTypeMin, int64_t storageTypeMax,
+           DefaultingPyMlirContext /*context*/) {
+          return AnyQuantizedType(
+              storageType.getContext(),
+              mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
+                                      storageTypeMin, storageTypeMax));
+        },
+        "Gets an instance of AnyQuantizedType in the same context as the "
+        "provided storage type.",
+        nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+        nb::arg("storage_type_min"), nb::arg("storage_type_max"),
+        nb::arg("context").none() = nb::none());
+  }
+};
 
-  auto uniformQuantizedType = mlir_type_subclass(
-      m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
-      quantizedType.get_class());
-  uniformQuantizedType.def_classmethod(
-      "get",
-      [](const nb::object &cls, unsigned flags, MlirType storageType,
-         MlirType expressedType, double scale, int64_t zeroPoint,
-         int64_t storageTypeMin, int64_t storageTypeMax) {
-        return cls(mlirUniformQuantizedTypeGet(flags, storageType,
-                                               expressedType, scale, zeroPoint,
-                                               storageTypeMin, storageTypeMax));
-      },
-      "Gets an instance of UniformQuantizedType in the same context as the "
-      "provided storage type.",
-      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
-      nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"),
-      nb::arg("storage_type_min"), nb::arg("storage_type_max"));
-  uniformQuantizedType.def_property_readonly(
-      "scale",
-      [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
-      "The scale designates the difference between the real values "
-      "corresponding to consecutive quantized values differing by 1.");
-  uniformQuantizedType.def_property_readonly(
-      "zero_point",
-      [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
-      "The storage value corresponding to the real value 0 in the affine "
-      "equation.");
-  uniformQuantizedType.def_property_readonly(
-      "is_fixed_point",
-      [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
-      "Fixed point values are real numbers divided by a scale.");
+//===-------------------------------------------------------------------===//
+// UniformQuantizedType
+//===-------------------------------------------------------------------===//
 
-  //===-------------------------------------------------------------------===//
-  // UniformQuantizedPerAxisType
-  //===-------------------------------------------------------------------===//
-  auto uniformQuantizedPerAxisType = mlir_type_subclass(
-      m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
-      quantizedType.get_class());
-  uniformQuantizedPerAxisType.def_classmethod(
-      "get",
-      [](const nb::object &cls, unsigned flags, MlirType storageType,
-         MlirType expressedType, std::vector<double> scales,
-         std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
-         int64_t storageTypeMin, int64_t storageTypeMax) {
-        if (scales.size() != zeroPoints.size())
-          throw nb::value_error(
-              "Mismatching number of scales and zero points.");
-        auto nDims = static_cast<intptr_t>(scales.size());
-        return cls(mlirUniformQuantizedPerAxisTypeGet(
-            flags, storageType, expressedType, nDims, scales.data(),
-            zeroPoints.data(), quantizedDimension, storageTypeMin,
-            storageTypeMax));
-      },
-      "Gets an instance of UniformQuantizedPerAxisType in the same context as "
-      "the provided storage type.",
-      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
-      nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
-      nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
-      nb::arg("storage_type_max"));
-  uniformQuantizedPerAxisType.def_property_readonly(
-      "scales",
-      [](MlirType type) {
-        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
-        std::vector<double> scales;
-        scales.reserve(nDim);
-        for (intptr_t i = 0; i < nDim; ++i) {
-          double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
-          scales.push_back(scale);
-        }
-        return scales;
-      },
-      "The scales designate the difference between the real values "
-      "corresponding to consecutive quantized values differing by 1. The ith "
-      "scale corresponds to the ith slice in the quantized_dimension.");
-  uniformQuantizedPerAxisType.def_property_readonly(
-      "zero_points",
-      [](MlirType type) {
-        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
-        std::vector<int64_t> zeroPoints;
-        zeroPoints.reserve(nDim);
-        for (intptr_t i = 0; i < nDim; ++i) {
-          int64_t zeroPoint =
-              mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
-          zeroPoints.push_back(zeroPoint);
-        }
-        return zeroPoints;
-      },
-      "the storage values corresponding to the real value 0 in the affine "
-      "equation. The ith zero point corresponds to the ith slice in the "
-      "quantized_dimension.");
-  uniformQuantizedPerAxisType.def_property_readonly(
-      "quantized_dimension",
-      [](MlirType type) {
-        return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
-      },
-      "Specifies the dimension of the shape that the scales and zero points "
-      "correspond to.");
-  uniformQuantizedPerAxisType.def_property_readonly(
-      "is_fixed_point",
-      [](MlirType type) {
-        return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
-      },
-      "Fixed point values are real numbers divided by a scale.");
+struct UniformQuantizedType
+    : PyConcreteType<UniformQuantizedType, QuantizedType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType;
+  static constexpr const char *pyClassName = "UniformQuantizedType";
+  using PyConcreteType::PyConcreteType;
 
-  //===-------------------------------------------------------------------===//
-  // UniformQuantizedSubChannelType
-  //===-------------------------------------------------------------------===//
-  auto uniformQuantizedSubChannelType = mlir_type_subclass(
-      m, "UniformQuantizedSubChannelType",
-      mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
-  uniformQuantizedSubChannelType.def_classmethod(
-      "get",
-      [](const nb::object &cls, unsigned flags, MlirType storageType,
-         MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
-         std::vector<int32_t> quantizedDimensions,
-         std::vector<int64_t> blockSizes, int64_t storageTypeMin,
-         int64_t storageTypeMax) {
-        return cls(mlirUniformQuantizedSubChannelTypeGet(
-            flags, storageType, expressedType, scales, zeroPoints,
-            static_cast<intptr_t>(blockSizes.size()),
-            quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
-            storageTypeMax));
-      },
-      "Gets an instance of UniformQuantizedSubChannel in the same context as "
-      "the provided storage type.",
-      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
-      nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
-      nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
-      nb::arg("storage_type_min"), nb::arg("storage_type_max"));
-  uniformQuantizedSubChannelType.def_property_readonly(
-      "quantized_dimensions",
-      [](MlirType type) {
-        intptr_t nDim =
-            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
-        std::vector<int32_t> quantizedDimensions;
-        quantizedDimensions.reserve(nDim);
-        for (intptr_t i = 0; i < nDim; ++i) {
-          quantizedDimensions.push_back(
-              mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
-        }
-        return quantizedDimensions;
-      },
-      "Gets the quantized dimensions. Each element in the returned list "
-      "represents an axis of the quantized data tensor that has a specified "
-      "block size. The order of elements corresponds to the order of block "
-      "sizes returned by 'block_sizes' method. It means that the data tensor "
-      "is quantized along the i-th dimension in the returned list using the "
-      "i-th block size from block_sizes method.");
-  uniformQuantizedSubChannelType.def_property_readonly(
-      "block_sizes",
-      [](MlirType type) {
-        intptr_t nDim =
-            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
-        std::vector<int64_t> blockSizes;
-        blockSizes.reserve(nDim);
-        for (intptr_t i = 0; i < nDim; ++i) {
-          blockSizes.push_back(
-              mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
-        }
-        return blockSizes;
-      },
-      "Gets the block sizes for the quantized dimensions. The i-th element in "
-      "the returned list corresponds to the block size for the i-th dimension "
-      "in the list returned by quantized_dimensions method.");
-  uniformQuantizedSubChannelType.def_property_readonly(
-      "scales",
-      [](MlirType type) -> MlirAttribute {
-        return mlirUniformQuantizedSubChannelTypeGetScales(type);
-      },
-      "The scales of the quantized type.");
-  uniformQuantizedSubChannelType.def_property_readonly(
-      "zero_points",
-      [](MlirType type) -> MlirAttribute {
-        return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
-      },
-      "The zero points of the quantized type.");
+  /// Registers the `get` factory and the read-only `scale`, `zero_point`
+  /// and `is_fixed_point` properties on the Python class.
+  static void bindDerived(ClassTy &c) {
+    // Per its docstring, `get` builds the type in the storage type's context,
+    // while the returned wrapper holds the resolved `context` argument.
+    // NOTE(review): nothing here checks that those two contexts agree.
+    c.def_static(
+        "get",
+        [](unsigned flags, const PyType &storageType,
+           const PyType &expressedType, double scale, int64_t zeroPoint,
+           int64_t storageTypeMin, int64_t storageTypeMax,
+           DefaultingPyMlirContext context) {
+          MlirType uniform = mlirUniformQuantizedTypeGet(
+              flags, storageType, expressedType, scale, zeroPoint,
+              storageTypeMin, storageTypeMax);
+          return UniformQuantizedType(context->getRef(), uniform);
+        },
+        "Gets an instance of UniformQuantizedType in the same context as the "
+        "provided storage type.",
+        nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+        nb::arg("scale"), nb::arg("zero_point"), nb::arg("storage_type_min"),
+        nb::arg("storage_type_max"), nb::arg("context") = nb::none());
+    // Read-only accessors; each forwards straight to the matching C API
+    // query on the wrapped type.
+    c.def_prop_ro(
+        "scale",
+        [](const PyType &t) { return mlirUniformQuantizedTypeGetScale(t); },
+        "The scale designates the difference between the real values "
+        "corresponding to consecutive quantized values differing by 1.");
+    c.def_prop_ro(
+        "zero_point",
+        [](const PyType &t) { return mlirUniformQuantizedTypeGetZeroPoint(t); },
+        "The storage value corresponding to the real value 0 in the affine "
+        "equation.");
+    c.def_prop_ro(
+        "is_fixed_point",
+        [](const PyType &t) { return mlirUniformQuantizedTypeIsFixedPoint(t); },
+        "Fixed point values are real numbers divided by a scale.");
+  }
+};
 
-  //===-------------------------------------------------------------------===//
-  // CalibratedQuantizedType
-  //===-------------------------------------------------------------------===//
+//===-------------------------------------------------------------------===//
+// UniformQuantizedPerAxisType
+//===-------------------------------------------------------------------===//
 
-  auto calibratedQuantizedType = mlir_type_subclass(
-      m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
-      quantizedType.get_class());
-  calibratedQuantizedType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType expressedType, double min,
-         double max) {
-        return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
-      },
-      "Gets an instance of CalibratedQuantizedType in the same context as the "
-      "provided expressed type.",
-      nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"),
-      nb::arg("max"));
-  calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
-    return mlirCalibratedQuantizedTypeGetMin(type);
-  });
-  calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
-    return mlirCalibratedQuantizedTypeGetMax(type);
-  });
+/// Python binding for per-axis uniform quantized types: one scale and one
+/// zero point per slice along `quantized_dimension`.
+struct UniformQuantizedPerAxisType
+    : PyConcreteType<UniformQuantizedPerAxisType, QuantizedType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAUniformQuantizedPerAxisType;
+  static constexpr const char *pyClassName = "UniformQuantizedPerAxisType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](unsigned flags, const PyType &storageType,
+           const PyType &expressedType, std::vector<double> scales,
+           std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
+           int64_t storageTypeMin, int64_t storageTypeMax,
+           DefaultingPyMlirContext context) {
+          // One length is passed to the C API for both arrays, so the two
+          // vectors must agree before their data pointers are handed over.
+          if (zeroPoints.size() != scales.size())
+            throw nb::value_error(
+                "Mismatching number of scales and zero points.");
+          MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
+              flags, storageType, expressedType,
+              static_cast<intptr_t>(scales.size()), scales.data(),
+              zeroPoints.data(), quantizedDimension, storageTypeMin,
+              storageTypeMax);
+          return UniformQuantizedPerAxisType(context->getRef(), perAxis);
+        },
+        "Gets an instance of UniformQuantizedPerAxisType in the same context "
+        "as the provided storage type.",
+        nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+        nb::arg("scales"), nb::arg("zero_points"),
+        nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
+        nb::arg("storage_type_max"), nb::arg("context") = nb::none());
+    // `scales` and `zero_points` each yield one entry per index in
+    // [0, mlirUniformQuantizedPerAxisTypeGetNumDims(type)).
+    c.def_prop_ro(
+        "scales",
+        [](const PyType &type) {
+          intptr_t count = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
+          std::vector<double> out;
+          out.reserve(count);
+          for (intptr_t idx = 0; idx < count; ++idx)
+            out.push_back(mlirUniformQuantizedPerAxisTypeGetScale(type, idx));
+          return out;
+        },
+        "The scales designate the difference between the real values "
+        "corresponding to consecutive quantized values differing by 1. The ith "
+        "scale corresponds to the ith slice in the quantized_dimension.");
+    c.def_prop_ro(
+        "zero_points",
+        [](const PyType &type) {
+          intptr_t count = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
+          std::vector<int64_t> out;
+          out.reserve(count);
+          for (intptr_t idx = 0; idx < count; ++idx)
+            out.push_back(
+                mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, idx));
+          return out;
+        },
+        "the storage values corresponding to the real value 0 in the affine "
+        "equation. The ith zero point corresponds to the ith slice in the "
+        "quantized_dimension.");
+    c.def_prop_ro(
+        "quantized_dimension",
+        [](const PyType &type) {
+          return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
+        },
+        "Specifies the dimension of the shape that the scales and zero points "
+        "correspond to.");
+    c.def_prop_ro(
+        "is_fixed_point",
+        [](const PyType &type) {
+          return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
+        },
+        "Fixed point values are real numbers divided by a scale.");
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for sub-channel uniform quantized types, whose scales and
+/// zero points are attributes and whose blocks are given per dimension.
+struct UniformQuantizedSubChannelType
+    : PyConcreteType<UniformQuantizedSubChannelType, QuantizedType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAUniformQuantizedSubChannelType;
+  static constexpr const char *pyClassName = "UniformQuantizedSubChannelType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](unsigned flags, const PyType &storageType,
+           const PyType &expressedType, MlirAttribute scales,
+           MlirAttribute zeroPoints, std::vector<int32_t> quantizedDimensions,
+           std::vector<int64_t> blockSizes, int64_t storageTypeMin,
+           int64_t storageTypeMax, DefaultingPyMlirContext context) {
+          // A single length, blockSizes.size(), is passed for both arrays.
+          if (quantizedDimensions.size() != blockSizes.size())
+            throw nb::value_error(
+                "Mismatching number of quantized dimensions and block sizes.");
+          return UniformQuantizedSubChannelType(
+              context->getRef(),
+              mlirUniformQuantizedSubChannelTypeGet(
+                  flags, storageType, expressedType, scales, zeroPoints,
+                  static_cast<intptr_t>(blockSizes.size()),
+                  quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
+                  storageTypeMax));
+        },
+        "Gets an instance of UniformQuantizedSubChannel in the same context as "
+        "the provided storage type.",
+        nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+        nb::arg("scales"), nb::arg("zero_points"),
+        nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
+        nb::arg("storage_type_min"), nb::arg("storage_type_max"),
+        nb::arg("context") = nb::none());
+    c.def_prop_ro(
+        "quantized_dimensions",
+        [](const PyType &type) {
+          auto nDim = mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+          std::vector<int32_t> quantizedDimensions;
+          quantizedDimensions.reserve(nDim);
+          for (intptr_t i = 0; i < nDim; ++i)
+            quantizedDimensions.push_back(
+                mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type,
+                                                                        i));
+          return quantizedDimensions;
+        },
+        "Gets the quantized dimensions. Each element in the returned list "
+        "represents an axis of the quantized data tensor that has a specified "
+        "block size. The order of elements corresponds to the order of block "
+        "sizes returned by 'block_sizes' method. It means that the data tensor "
+        "is quantized along the i-th dimension in the returned list using the "
+        "i-th block size from block_sizes method.");
+    c.def_prop_ro(
+        "block_sizes",
+        [](const PyType &type) {
+          auto nDim = mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+          std::vector<int64_t> blockSizes;
+          blockSizes.reserve(nDim);
+          for (intptr_t i = 0; i < nDim; ++i)
+            blockSizes.push_back(
+                mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
+          return blockSizes;
+        },
+        "Gets the block sizes for the quantized dimensions. The i-th element "
+        "in the returned list corresponds to the block size for the i-th "
+        "dimension in the list returned by quantized_dimensions method.");
+    c.def_prop_ro(
+        "scales",
+        [](const PyType &type) -> MlirAttribute {
+          return mlirUniformQuantizedSubChannelTypeGetScales(type);
+        },
+        "The scales of the quantized type.");
+    c.def_prop_ro(
+        "zero_points",
+        [](const PyType &type) -> MlirAttribute {
+          return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+        },
+        "The zero points of the quantized type.");
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// CalibratedQuantizedType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for calibrated quantized types (expressed type, min, max).
+struct CalibratedQuantizedType
+    : PyConcreteType<CalibratedQuantizedType, QuantizedType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsACalibratedQuantizedType;
+  static constexpr const char *pyClassName = "CalibratedQuantizedType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &expressedType, double min, double max,
+           DefaultingPyMlirContext context) {
+          MlirType calibrated =
+              mlirCalibratedQuantizedTypeGet(expressedType, min, max);
+          return CalibratedQuantizedType(context->getRef(), calibrated);
+        },
+        "Gets an instance of CalibratedQuantizedType in the same context as "
+        "the provided expressed type.",
+        nb::arg("expressed_type"), nb::arg("min"), nb::arg("max"),
+        nb::arg("context") = nb::none());
+    c.def_prop_ro("min", [](const PyType &t) {
+      return mlirCalibratedQuantizedTypeGetMin(t);
+    });
+    c.def_prop_ro("max", [](const PyType &t) {
+      return mlirCalibratedQuantizedTypeGetMax(t);
+    });
+  }
+};
+
+static void populateDialectQuantSubmodule(nb::module_ &m) {
+  QuantizedType::bind(m);
+  // FLAG_SIGNED is attached to the QuantizedType class object, which is
+  // looked up from the module right after QuantizedType::bind has run.
+  auto baseClass = m.attr("QuantizedType");
+  baseClass.attr("FLAG_SIGNED") = mlirQuantizedTypeGetSignedFlag();
+  // Concrete QuantizedType subclasses.
+  AnyQuantizedType::bind(m);
+  UniformQuantizedType::bind(m);
+  UniformQuantizedPerAxisType::bind(m);
+  UniformQuantizedSubChannelType::bind(m);
+  CalibratedQuantizedType::bind(m);
 }
+} // namespace quant
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsQuant, m) {
   m.doc() = "MLIR Quantization dialect";
 
-  populateDialectQuantSubmodule(m);
+  namespace q = mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::quant;
+  q::populateDialectQuantSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index a87918a05b126..39490155d5216 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -13,44 +13,77 @@
 #include "mlir-c/Support.h"
 #include "mlir-c/Target/ExportSMTLIB.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 
 using namespace nanobind::literals;
-
 using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace smt {
+/// Python binding for the SMT dialect's boolean type.
+struct BoolType : PyConcreteType<BoolType> {
+  static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABool;
+  static constexpr const char *pyClassName = "BoolType";
+  using PyConcreteType::PyConcreteType;
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType boolType = mlirSMTTypeGetBool(context->get());
+          return BoolType(context->getRef(), boolType);
+        },
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+/// Python binding for the SMT dialect's bit-vector type of a given width.
+struct BitVectorType : PyConcreteType<BitVectorType> {
+  static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABitVector;
+  static constexpr const char *pyClassName = "BitVectorType";
 
-  auto smtBoolType =
-      mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
-          .def_staticmethod(
-              "get",
-              [](MlirContext context) { return mlirSMTTypeGetBool(context); },
-              "context"_a = nb::none());
-  auto smtBitVectorType =
-      mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
-          .def_staticmethod(
-              "get",
-              [](int32_t width, MlirContext context) {
-                return mlirSMTTypeGetBitVector(context, width);
-              },
-              "width"_a, "context"_a = nb::none());
-  auto smtIntType =
-      mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
-          .def_staticmethod(
-              "get",
-              [](MlirContext context) { return mlirSMTTypeGetInt(context); },
-              "context"_a = nb::none());
+  using PyConcreteType::PyConcreteType;
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](int32_t width, DefaultingPyMlirContext context) {
+          MlirType bvType = mlirSMTTypeGetBitVector(context->get(), width);
+          return BitVectorType(context->getRef(), bvType);
+        },
+        nb::arg("width"), nb::arg("context").none() = nb::none());
+  }
+};
+
+/// Python binding for the SMT dialect's integer type.
+struct IntType : PyConcreteType<IntType> {
+  static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsAInt;
+  static constexpr const char *pyClassName = "IntType";
+  using PyConcreteType::PyConcreteType;
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType intType = mlirSMTTypeGetInt(context->get());
+          return IntType(context->getRef(), intType);
+        },
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+  BoolType::bind(m);
+  BitVectorType::bind(m);
+  IntType::bind(m);
 
   auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
                          bool indentLetBody) {
-    mlir::python::CollectDiagnosticsToStringScope scope(
-        mlirOperationGetContext(module));
+    CollectDiagnosticsToStringScope scope(mlirOperationGetContext(module));
     PyPrintAccumulator printAccum;
     MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
         module, printAccum.getCallback(), printAccum.getUserData(),
@@ -80,9 +113,13 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
       "module"_a, "inline_single_use_values"_a = false,
       "indent_let_body"_a = false);
 }
+} // namespace smt
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsSMT, m) {
   m.doc() = "MLIR SMT Dialect";
 
-  populateDialectSMTSubmodule(m);
+  python::MLIR_BINDINGS_PYTHON_DOMAIN::smt::populateDialectSMTSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 00b65ee9745dc..6ec58dd88d24f 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -12,137 +12,179 @@
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/Dialect/SparseTensor.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace llvm;
-using namespace mlir;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
-  nb::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", nb::is_arithmetic(),
-                                         nb::is_flag())
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace sparse_tensor {
+// Mirror of MlirSparseTensorLevelFormat; each RHS names the C enumerator.
+enum PySparseTensorLevelFormat : std::underlying_type_t<
+    MlirSparseTensorLevelFormat> {
+  MLIR_SPARSE_TENSOR_LEVEL_DENSE = MLIR_SPARSE_TENSOR_LEVEL_DENSE,
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = MLIR_SPARSE_TENSOR_LEVEL_SINGLETON,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED =
+      MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED
+};
+// Mirror of MlirSparseTensorLevelPropertyNondefault (same names/values).
+enum PySparseTensorLevelPropertyNondefault : std::underlying_type_t<
+    MlirSparseTensorLevelPropertyNondefault> {
+  MLIR_SPARSE_PROPERTY_NON_ORDERED = MLIR_SPARSE_PROPERTY_NON_ORDERED,
+  MLIR_SPARSE_PROPERTY_NON_UNIQUE = MLIR_SPARSE_PROPERTY_NON_UNIQUE,
+  MLIR_SPARSE_PROPERTY_SOA = MLIR_SPARSE_PROPERTY_SOA,
+};
+
+/// Python binding for sparse_tensor.encoding attributes.
+struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirAttributeIsASparseTensorEncodingAttr;
+  static constexpr const char *pyClassName = "EncodingAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static void bindDerived(ClassTy &c) {
+    // Absent optional components are passed to the C API as null handles.
+    c.def_static(
+        "get",
+        [](std::vector<MlirSparseTensorLevelType> lvlTypes,
+           std::optional<MlirAffineMap> dimToLvl,
+           std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
+           std::optional<MlirAttribute> explicitVal,
+           std::optional<MlirAttribute> implicitVal,
+           DefaultingPyMlirContext context) {
+          MlirAttribute encoding = mlirSparseTensorEncodingAttrGet(
+              context->get(), lvlTypes.size(), lvlTypes.data(),
+              dimToLvl.value_or(MlirAffineMap{nullptr}),
+              lvlToDim.value_or(MlirAffineMap{nullptr}), posWidth, crdWidth,
+              explicitVal.value_or(MlirAttribute{nullptr}),
+              implicitVal.value_or(MlirAttribute{nullptr}));
+          return EncodingAttr(context->getRef(), encoding);
+        },
+        nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
+        nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
+        nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(),
+        nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(),
+        "Gets a sparse_tensor.encoding from parameters.");
+
+    // Python-side enum values are cast back to the C enums for the C API.
+    c.def_static(
+        "build_level_type",
+        [](PySparseTensorLevelFormat lvlFmt,
+           const std::vector<PySparseTensorLevelPropertyNondefault> &properties,
+           unsigned n, unsigned m) {
+          std::vector<MlirSparseTensorLevelPropertyNondefault> props;
+          props.reserve(properties.size());
+          for (PySparseTensorLevelPropertyNondefault prop : properties)
+            props.push_back(
+                static_cast<MlirSparseTensorLevelPropertyNondefault>(prop));
+          auto fmt = static_cast<MlirSparseTensorLevelFormat>(lvlFmt);
+          return mlirSparseTensorEncodingAttrBuildLvlType(fmt, props.data(),
+                                                          props.size(), n, m);
+        },
+        nb::arg("lvl_fmt"),
+        nb::arg("properties") =
+            std::vector<PySparseTensorLevelPropertyNondefault>(),
+        nb::arg("n") = 0, nb::arg("m") = 0,
+        "Builds a sparse_tensor.encoding.level_type from parameters.");
+
+    c.def_prop_ro("lvl_types", [](MlirAttribute self) {
+      const int rank = mlirSparseTensorEncodingGetLvlRank(self);
+      std::vector<MlirSparseTensorLevelType> types;
+      types.reserve(rank);
+      for (int lvl = 0; lvl < rank; ++lvl)
+        types.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, lvl));
+      return types;
+    });
+    // Optional components: a null C handle is returned as an empty optional.
+    c.def_prop_ro(
+        "dim_to_lvl", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+          MlirAffineMap map = mlirSparseTensorEncodingAttrGetDimToLvl(self);
+          if (!mlirAffineMapIsNull(map))
+            return map;
+          return std::nullopt;
+        });
+
+    c.def_prop_ro(
+        "lvl_to_dim", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+          MlirAffineMap map = mlirSparseTensorEncodingAttrGetLvlToDim(self);
+          if (!mlirAffineMapIsNull(map))
+            return map;
+          return std::nullopt;
+        });
+
+    c.def_prop_ro("pos_width", mlirSparseTensorEncodingAttrGetPosWidth);
+    c.def_prop_ro("crd_width", mlirSparseTensorEncodingAttrGetCrdWidth);
+
+    c.def_prop_ro(
+        "explicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+          MlirAttribute val = mlirSparseTensorEncodingAttrGetExplicitVal(self);
+          if (!mlirAttributeIsNull(val))
+            return val;
+          return std::nullopt;
+        });
+
+    c.def_prop_ro(
+        "implicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+          MlirAttribute val = mlirSparseTensorEncodingAttrGetImplicitVal(self);
+          if (!mlirAttributeIsNull(val))
+            return val;
+          return std::nullopt;
+        });
+
+    // structured_n / structured_m are read from the last level's type.
+    c.def_prop_ro("structured_n", [](MlirAttribute self) -> unsigned {
+      const int lastLvl = mlirSparseTensorEncodingGetLvlRank(self) - 1;
+      return mlirSparseTensorEncodingAttrGetStructuredN(
+          mlirSparseTensorEncodingAttrGetLvlType(self, lastLvl));
+    });
+    c.def_prop_ro("structured_m", [](MlirAttribute self) -> unsigned {
+      const int lastLvl = mlirSparseTensorEncodingGetLvlRank(self) - 1;
+      return mlirSparseTensorEncodingAttrGetStructuredM(
+          mlirSparseTensorEncodingAttrGetLvlType(self, lastLvl));
+    });
+
+    c.def_prop_ro("lvl_formats_enum", [](MlirAttribute self) {
+      const int rank = mlirSparseTensorEncodingGetLvlRank(self);
+      std::vector<PySparseTensorLevelFormat> formats;
+      formats.reserve(rank);
+      // Re-tag each C level format as the module-local enum for Python.
+      for (int lvl = 0; lvl < rank; lvl++)
+        formats.push_back(static_cast<PySparseTensorLevelFormat>(
+            mlirSparseTensorEncodingAttrGetLvlFmt(self, lvl)));
+      return formats;
+    });
+  }
+};
+
+static void populateDialectSparseTensorSubmodule(nb::module_ &m) {
+  nb::enum_<PySparseTensorLevelFormat>(
+      m, "LevelFormat", nb::is_arithmetic(), nb::is_flag())
       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
       .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
       .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
       .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED);
 
-  nb::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty")
+  nb::enum_<PySparseTensorLevelPropertyNondefault>(m, "LevelProperty")
       .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
       .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE)
       .value("soa", MLIR_SPARSE_PROPERTY_SOA);
 
-  mlir_attribute_subclass(m, "EncodingAttr",
-                          mlirAttributeIsASparseTensorEncodingAttr)
-      .def_classmethod(
-          "get",
-          [](const nb::object &cls,
-             std::vector<MlirSparseTensorLevelType> lvlTypes,
-             std::optional<MlirAffineMap> dimToLvl,
-             std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
-             std::optional<MlirAttribute> explicitVal,
-             std::optional<MlirAttribute> implicitVal, MlirContext context) {
-            return cls(mlirSparseTensorEncodingAttrGet(
-                context, lvlTypes.size(), lvlTypes.data(),
-                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
-                lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
-                crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
-                implicitVal ? *implicitVal : MlirAttribute{nullptr}));
-          },
-          nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
-          nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
-          nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(),
-          nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(),
-          "Gets a sparse_tensor.encoding from parameters.")
-      .def_classmethod(
-          "build_level_type",
-          [](const nb::object &cls, MlirSparseTensorLevelFormat lvlFmt,
-             const std::vector<MlirSparseTensorLevelPropertyNondefault>
-                 &properties,
-             unsigned n, unsigned m) {
-            return mlirSparseTensorEncodingAttrBuildLvlType(
-                lvlFmt, properties.data(), properties.size(), n, m);
-          },
-          nb::arg("cls"), nb::arg("lvl_fmt"),
-          nb::arg("properties") =
-              std::vector<MlirSparseTensorLevelPropertyNondefault>(),
-          nb::arg("n") = 0, nb::arg("m") = 0,
-          "Builds a sparse_tensor.encoding.level_type from parameters.")
-      .def_property_readonly(
-          "lvl_types",
-          [](MlirAttribute self) {
-            const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-            std::vector<MlirSparseTensorLevelType> ret;
-            ret.reserve(lvlRank);
-            for (int l = 0; l < lvlRank; ++l)
-              ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
-            return ret;
-          })
-      .def_property_readonly(
-          "dim_to_lvl",
-          [](MlirAttribute self) -> std::optional<MlirAffineMap> {
-            MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
-            if (mlirAffineMapIsNull(ret))
-              return {};
-            return ret;
-          })
-      .def_property_readonly(
-          "lvl_to_dim",
-          [](MlirAttribute self) -> std::optional<MlirAffineMap> {
-            MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
-            if (mlirAffineMapIsNull(ret))
-              return {};
-            return ret;
-          })
-      .def_property_readonly("pos_width",
-                             mlirSparseTensorEncodingAttrGetPosWidth)
-      .def_property_readonly("crd_width",
-                             mlirSparseTensorEncodingAttrGetCrdWidth)
-      .def_property_readonly(
-          "explicit_val",
-          [](MlirAttribute self) -> std::optional<MlirAttribute> {
-            MlirAttribute ret =
-                mlirSparseTensorEncodingAttrGetExplicitVal(self);
-            if (mlirAttributeIsNull(ret))
-              return {};
-            return ret;
-          })
-      .def_property_readonly(
-          "implicit_val",
-          [](MlirAttribute self) -> std::optional<MlirAttribute> {
-            MlirAttribute ret =
-                mlirSparseTensorEncodingAttrGetImplicitVal(self);
-            if (mlirAttributeIsNull(ret))
-              return {};
-            return ret;
-          })
-      .def_property_readonly(
-          "structured_n",
-          [](MlirAttribute self) -> unsigned {
-            const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-            return mlirSparseTensorEncodingAttrGetStructuredN(
-                mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
-          })
-      .def_property_readonly(
-          "structured_m",
-          [](MlirAttribute self) -> unsigned {
-            const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-            return mlirSparseTensorEncodingAttrGetStructuredM(
-                mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
-          })
-      .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) {
-        const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-        std::vector<MlirSparseTensorLevelFormat> ret;
-        ret.reserve(lvlRank);
-        for (int l = 0; l < lvlRank; l++)
-          ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
-        return ret;
-      });
+  EncodingAttr::bind(m);
 }
+} // namespace sparse_tensor
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsSparseTensor, m) {
   m.doc() = "MLIR SparseTensor dialect.";
-  populateDialectSparseTensorSubmodule(m);
+  namespace st = mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::sparse_tensor;
+  st::populateDialectSparseTensorSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 150c69953d960..f42ebd004d09f 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,112 +11,165 @@
 #include "mlir-c/Dialect/Transform.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectTransformSubmodule(const nb::module_ &m) {
-  //===-------------------------------------------------------------------===//
-  // AnyOpType
-  //===-------------------------------------------------------------------===//
-
-  auto anyOpType =
-      mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
-                         mlirTransformAnyOpTypeGetTypeID);
-  anyOpType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirTransformAnyOpTypeGet(ctx));
-      },
-      "Get an instance of AnyOpType in the given context.", nb::arg("cls"),
-      nb::arg("context") = nb::none());
-
-  //===-------------------------------------------------------------------===//
-  // AnyParamType
-  //===-------------------------------------------------------------------===//
-
-  auto anyParamType =
-      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
-                         mlirTransformAnyParamTypeGetTypeID);
-  anyParamType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirTransformAnyParamTypeGet(ctx));
-      },
-      "Get an instance of AnyParamType in the given context.", nb::arg("cls"),
-      nb::arg("context") = nb::none());
-
-  //===-------------------------------------------------------------------===//
-  // AnyValueType
-  //===-------------------------------------------------------------------===//
-
-  auto anyValueType =
-      mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
-                         mlirTransformAnyValueTypeGetTypeID);
-  anyValueType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirTransformAnyValueTypeGet(ctx));
-      },
-      "Get an instance of AnyValueType in the given context.", nb::arg("cls"),
-      nb::arg("context") = nb::none());
-
-  //===-------------------------------------------------------------------===//
-  // OperationType
-  //===-------------------------------------------------------------------===//
-
-  auto operationType =
-      mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
-                         mlirTransformOperationTypeGetTypeID);
-  operationType.def_classmethod(
-      "get",
-      [](const nb::object &cls, const std::string &operationName,
-         MlirContext ctx) {
-        MlirStringRef cOperationName =
-            mlirStringRefCreate(operationName.data(), operationName.size());
-        return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
-      },
-      "Get an instance of OperationType for the given kind in the given "
-      "context",
-      nb::arg("cls"), nb::arg("operation_name"),
-      nb::arg("context") = nb::none());
-  operationType.def_property_readonly(
-      "operation_name",
-      [](MlirType type) {
-        MlirStringRef operationName =
-            mlirTransformOperationTypeGetOperationName(type);
-        return nb::str(operationName.data, operationName.length);
-      },
-      "Get the name of the payload operation accepted by the handle.");
-
-  //===-------------------------------------------------------------------===//
-  // ParamType
-  //===-------------------------------------------------------------------===//
-
-  auto paramType =
-      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
-                         mlirTransformParamTypeGetTypeID);
-  paramType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType type, MlirContext ctx) {
-        return cls(mlirTransformParamTypeGet(ctx, type));
-      },
-      "Get an instance of ParamType for the given type in the given context.",
-      nb::arg("cls"), nb::arg("type"), nb::arg("context") = nb::none());
-  paramType.def_property_readonly(
-      "type",
-      [](MlirType type) {
-        MlirType paramType = mlirTransformParamTypeGetType(type);
-        return paramType;
-      },
-      "Get the type this ParamType is associated with.");
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace transform {
+//===-------------------------------------------------------------------===//
+// AnyOpType
+//===-------------------------------------------------------------------===//
+
+struct AnyOpType : PyConcreteType<AnyOpType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyOpType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTransformAnyOpTypeGetTypeID;
+  static constexpr const char *pyClassName = "AnyOpType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return AnyOpType(context->getRef(),
+                           mlirTransformAnyOpTypeGet(context.get()->get()));
+        },
+        "Get an instance of AnyOpType in the given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// AnyParamType
+//===-------------------------------------------------------------------===//
+
+struct AnyParamType : PyConcreteType<AnyParamType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyParamType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTransformAnyParamTypeGetTypeID;
+  static constexpr const char *pyClassName = "AnyParamType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return AnyParamType(context->getRef(), mlirTransformAnyParamTypeGet(
+                                                     context.get()->get()));
+        },
+        "Get an instance of AnyParamType in the given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// AnyValueType
+//===-------------------------------------------------------------------===//
+
+struct AnyValueType : PyConcreteType<AnyValueType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyValueType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTransformAnyValueTypeGetTypeID;
+  static constexpr const char *pyClassName = "AnyValueType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return AnyValueType(context->getRef(), mlirTransformAnyValueTypeGet(
+                                                     context.get()->get()));
+        },
+        "Get an instance of AnyValueType in the given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// OperationType
+//===-------------------------------------------------------------------===//
+
+struct OperationType : PyConcreteType<OperationType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsATransformOperationType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTransformOperationTypeGetTypeID;
+  static constexpr const char *pyClassName = "OperationType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const std::string &operationName, DefaultingPyMlirContext context) {
+          MlirStringRef cOperationName =
+              mlirStringRefCreate(operationName.data(), operationName.size());
+          return OperationType(context->getRef(),
+                               mlirTransformOperationTypeGet(
+                                   context.get()->get(), cOperationName));
+        },
+        "Get an instance of OperationType for the given kind in the given "
+        "context",
+        nb::arg("operation_name"), nb::arg("context").none() = nb::none());
+    c.def_prop_ro(
+        "operation_name",
+        [](const PyType &type) {
+          MlirStringRef operationName =
+              mlirTransformOperationTypeGetOperationName(type);
+          return nb::str(operationName.data, operationName.length);
+        },
+        "Get the name of the payload operation accepted by the handle.");
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// ParamType
+//===-------------------------------------------------------------------===//
+
+struct ParamType : PyConcreteType<ParamType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformParamType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTransformParamTypeGetTypeID;
+  static constexpr const char *pyClassName = "ParamType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &type, DefaultingPyMlirContext context) {
+          return ParamType(context->getRef(), mlirTransformParamTypeGet(
+                                                  context.get()->get(), type));
+        },
+        "Get an instance of ParamType for the given type in the given context.",
+        nb::arg("type"), nb::arg("context").none() = nb::none());
+    c.def_prop_ro(
+        "type",
+        [](PyType type) {
+          return PyType(type.getContext(), mlirTransformParamTypeGetType(type));
+        },
+        "Get the type this ParamType is associated with.");
+  }
+};
+
+static void populateDialectTransformSubmodule(nb::module_ &m) {
+  AnyOpType::bind(m);
+  AnyParamType::bind(m);
+  AnyValueType::bind(m);
+  OperationType::bind(m);
+  ParamType::bind(m);
 }
+} // namespace transform
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsTransform, m) {
   m.doc() = "MLIR Transform dialect.";
-  populateDialectTransformSubmodule(m);
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::transform::
+      populateDialectTransformSubmodule(m);
 }
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a3..b4d19878056db 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -43,8 +43,9 @@ def __init__(
         self.parent = parent
         self.children = children if children is not None else []
 
-@ir.register_value_caster(AnyOpType.get_static_typeid())
-@ir.register_value_caster(OperationType.get_static_typeid())
+
+@ir.register_value_caster(AnyOpType.static_typeid)
+@ir.register_value_caster(OperationType.static_typeid)
 class OpHandle(Handle):
     """
     Wrapper around a transform operation handle with methods to chain further
@@ -132,8 +133,8 @@ def print(self, name: Optional[str] = None) -> "OpHandle":
         return self
 
 
-@ir.register_value_caster(AnyParamType.get_static_typeid())
-@ir.register_value_caster(ParamType.get_static_typeid())
+@ir.register_value_caster(AnyParamType.static_typeid)
+@ir.register_value_caster(ParamType.static_typeid)
 class ParamHandle(Handle):
     """Wrapper around a transform param handle."""
 
@@ -147,7 +148,7 @@ def __init__(
         super().__init__(v, parent=parent, children=children)
 
 
-@ir.register_value_caster(AnyValueType.get_static_typeid())
+@ir.register_value_caster(AnyValueType.static_typeid)
 class ValueHandle(Handle):
     """
     Wrapper around a transform value handle with methods to chain further
diff --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py
index dfba2a36b8980..f75428d295c9c 100644
--- a/mlir/test/python/dialects/pdl_types.py
+++ b/mlir/test/python/dialects/pdl_types.py
@@ -5,149 +5,149 @@
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: test_attribute_type
 @run
 def test_attribute_type():
-  with Context():
-    parsedType = Type.parse("!pdl.attribute")
-    constructedType = pdl.AttributeType.get()
+    with Context():
+        parsedType = Type.parse("!pdl.attribute")
+        constructedType = pdl.AttributeType.get()
 
-    assert pdl.AttributeType.isinstance(parsedType)
-    assert not pdl.OperationType.isinstance(parsedType)
-    assert not pdl.RangeType.isinstance(parsedType)
-    assert not pdl.TypeType.isinstance(parsedType)
-    assert not pdl.ValueType.isinstance(parsedType)
+        assert pdl.AttributeType.isinstance(parsedType)
+        assert not pdl.OperationType.isinstance(parsedType)
+        assert not pdl.RangeType.isinstance(parsedType)
+        assert not pdl.TypeType.isinstance(parsedType)
+        assert not pdl.ValueType.isinstance(parsedType)
 
-    assert pdl.AttributeType.isinstance(constructedType)
-    assert not pdl.OperationType.isinstance(constructedType)
-    assert not pdl.RangeType.isinstance(constructedType)
-    assert not pdl.TypeType.isinstance(constructedType)
-    assert not pdl.ValueType.isinstance(constructedType)
+        assert pdl.AttributeType.isinstance(constructedType)
+        assert not pdl.OperationType.isinstance(constructedType)
+        assert not pdl.RangeType.isinstance(constructedType)
+        assert not pdl.TypeType.isinstance(constructedType)
+        assert not pdl.ValueType.isinstance(constructedType)
 
-    assert parsedType == constructedType
+        assert parsedType == constructedType
 
-    # CHECK: !pdl.attribute
-    print(parsedType)
-    # CHECK: !pdl.attribute
-    print(constructedType)
+        # CHECK: !pdl.attribute
+        print(parsedType)
+        # CHECK: !pdl.attribute
+        print(constructedType)
 
 
 # CHECK-LABEL: TEST: test_operation_type
 @run
 def test_operation_type():
-  with Context():
-    parsedType = Type.parse("!pdl.operation")
-    constructedType = pdl.OperationType.get()
+    with Context():
+        parsedType = Type.parse("!pdl.operation")
+        constructedType = pdl.OperationType.get()
 
-    assert not pdl.AttributeType.isinstance(parsedType)
-    assert pdl.OperationType.isinstance(parsedType)
-    assert not pdl.RangeType.isinstance(parsedType)
-    assert not pdl.TypeType.isinstance(parsedType)
-    assert not pdl.ValueType.isinstance(parsedType)
+        assert not pdl.AttributeType.isinstance(parsedType)
+        assert pdl.OperationType.isinstance(parsedType)
+        assert not pdl.RangeType.isinstance(parsedType)
+        assert not pdl.TypeType.isinstance(parsedType)
+        assert not pdl.ValueType.isinstance(parsedType)
 
-    assert not pdl.AttributeType.isinstance(constructedType)
-    assert pdl.OperationType.isinstance(constructedType)
-    assert not pdl.RangeType.isinstance(constructedType)
-    assert not pdl.TypeType.isinstance(constructedType)
-    assert not pdl.ValueType.isinstance(constructedType)
+        assert not pdl.AttributeType.isinstance(constructedType)
+        assert pdl.OperationType.isinstance(constructedType)
+        assert not pdl.RangeType.isinstance(constructedType)
+        assert not pdl.TypeType.isinstance(constructedType)
+        assert not pdl.ValueType.isinstance(constructedType)
 
-    assert parsedType == constructedType
+        assert parsedType == constructedType
 
-    # CHECK: !pdl.operation
-    print(parsedType)
-    # CHECK: !pdl.operation
-    print(constructedType)
+        # CHECK: !pdl.operation
+        print(parsedType)
+        # CHECK: !pdl.operation
+        print(constructedType)
 
 
 # CHECK-LABEL: TEST: test_range_type
 @run
 def test_range_type():
-  with Context():
-    typeType = Type.parse("!pdl.type")
-    parsedType = Type.parse("!pdl.range<type>")
-    constructedType = pdl.RangeType.get(typeType)
-    elementType = constructedType.element_type
-
-    assert not pdl.AttributeType.isinstance(parsedType)
-    assert not pdl.OperationType.isinstance(parsedType)
-    assert pdl.RangeType.isinstance(parsedType)
-    assert not pdl.TypeType.isinstance(parsedType)
-    assert not pdl.ValueType.isinstance(parsedType)
-
-    assert not pdl.AttributeType.isinstance(constructedType)
-    assert not pdl.OperationType.isinstance(constructedType)
-    assert pdl.RangeType.isinstance(constructedType)
-    assert not pdl.TypeType.isinstance(constructedType)
-    assert not pdl.ValueType.isinstance(constructedType)
-
-    assert parsedType == constructedType
-    assert elementType == typeType
-
-    # CHECK: !pdl.range<type>
-    print(parsedType)
-    # CHECK: !pdl.range<type>
-    print(constructedType)
-    # CHECK: !pdl.type
-    print(elementType)
+    with Context():
+        typeType = Type.parse("!pdl.type")
+        parsedType = Type.parse("!pdl.range<type>")
+        constructedType = pdl.RangeType.get(typeType)
+        elementType = constructedType.element_type
+
+        assert not pdl.AttributeType.isinstance(parsedType)
+        assert not pdl.OperationType.isinstance(parsedType)
+        assert pdl.RangeType.isinstance(parsedType)
+        assert not pdl.TypeType.isinstance(parsedType)
+        assert not pdl.ValueType.isinstance(parsedType)
+
+        assert not pdl.AttributeType.isinstance(constructedType)
+        assert not pdl.OperationType.isinstance(constructedType)
+        assert pdl.RangeType.isinstance(constructedType)
+        assert not pdl.TypeType.isinstance(constructedType)
+        assert not pdl.ValueType.isinstance(constructedType)
+
+        assert parsedType == constructedType
+        assert elementType == typeType
+
+        # CHECK: !pdl.range<type>
+        print(parsedType)
+        # CHECK: !pdl.range<type>
+        print(constructedType)
+        # CHECK: !pdl.type
+        print(elementType)
 
 
 # CHECK-LABEL: TEST: test_type_type
 @run
 def test_type_type():
-  with Context():
-    parsedType = Type.parse("!pdl.type")
-    constructedType = pdl.TypeType.get()
+    with Context():
+        parsedType = Type.parse("!pdl.type")
+        constructedType = pdl.TypeType.get()
 
-    assert not pdl.AttributeType.isinstance(parsedType)
-    assert not pdl.OperationType.isinstance(parsedType)
-    assert not pdl.RangeType.isinstance(parsedType)
-    assert pdl.TypeType.isinstance(parsedType)
-    assert not pdl.ValueType.isinstance(parsedType)
+        assert not pdl.AttributeType.isinstance(parsedType)
+        assert not pdl.OperationType.isinstance(parsedType)
+        assert not pdl.RangeType.isinstance(parsedType)
+        assert pdl.TypeType.isinstance(parsedType)
+        assert not pdl.ValueType.isinstance(parsedType)
 
-    assert not pdl.AttributeType.isinstance(constructedType)
-    assert not pdl.OperationType.isinstance(constructedType)
-    assert not pdl.RangeType.isinstance(constructedType)
-    assert pdl.TypeType.isinstance(constructedType)
-    assert not pdl.ValueType.isinstance(constructedType)
+        assert not pdl.AttributeType.isinstance(constructedType)
+        assert not pdl.OperationType.isinstance(constructedType)
+        assert not pdl.RangeType.isinstance(constructedType)
+        assert pdl.TypeType.isinstance(constructedType)
+        assert not pdl.ValueType.isinstance(constructedType)
 
-    assert parsedType == constructedType
+        assert parsedType == constructedType
 
-    # CHECK: !pdl.type
-    print(parsedType)
-    # CHECK: !pdl.type
-    print(constructedType)
+        # CHECK: !pdl.type
+        print(parsedType)
+        # CHECK: !pdl.type
+        print(constructedType)
 
 
 # CHECK-LABEL: TEST: test_value_type
 @run
 def test_value_type():
-  with Context():
-    parsedType = Type.parse("!pdl.value")
-    constructedType = pdl.ValueType.get()
+    with Context():
+        parsedType = Type.parse("!pdl.value")
+        constructedType = pdl.ValueType.get()
 
-    assert not pdl.AttributeType.isinstance(parsedType)
-    assert not pdl.OperationType.isinstance(parsedType)
-    assert not pdl.RangeType.isinstance(parsedType)
-    assert not pdl.TypeType.isinstance(parsedType)
-    assert pdl.ValueType.isinstance(parsedType)
+        assert not pdl.AttributeType.isinstance(parsedType)
+        assert not pdl.OperationType.isinstance(parsedType)
+        assert not pdl.RangeType.isinstance(parsedType)
+        assert not pdl.TypeType.isinstance(parsedType)
+        assert pdl.ValueType.isinstance(parsedType)
 
-    assert not pdl.AttributeType.isinstance(constructedType)
-    assert not pdl.OperationType.isinstance(constructedType)
-    assert not pdl.RangeType.isinstance(constructedType)
-    assert not pdl.TypeType.isinstance(constructedType)
-    assert pdl.ValueType.isinstance(constructedType)
+        assert not pdl.AttributeType.isinstance(constructedType)
+        assert not pdl.OperationType.isinstance(constructedType)
+        assert not pdl.RangeType.isinstance(constructedType)
+        assert not pdl.TypeType.isinstance(constructedType)
+        assert pdl.ValueType.isinstance(constructedType)
 
-    assert parsedType == constructedType
+        assert parsedType == constructedType
 
-    # CHECK: !pdl.value
-    print(parsedType)
-    # CHECK: !pdl.value
-    print(constructedType)
+        # CHECK: !pdl.value
+        print(parsedType)
+        # CHECK: !pdl.value
+        print(constructedType)
 
 
 # CHECK-LABEL: TEST: test_type_without_context
@@ -157,7 +157,10 @@ def test_type_without_context():
     # should raise an exception but not crash.
     try:
         constructedType = pdl.ValueType.get()
-    except TypeError:
-        pass
+    except RuntimeError as e:
+        assert (
+            "An MLIR function requires a Context but none was provided in the call or from the surrounding environment"
+            in e.args[0]
+        )
     else:
         assert False, "Expected TypeError to be raised."

>From dd70894dc8bd302c45bcae9c23e4f033f4a68d1e Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 3 Jan 2026 18:55:39 -0800
Subject: [PATCH 3/3] update signatures

---
 mlir/lib/Bindings/Python/DialectGPU.cpp       |  40 +++---
 mlir/lib/Bindings/Python/DialectLLVM.cpp      |  62 +++++----
 mlir/lib/Bindings/Python/DialectPDL.cpp       |   6 +-
 mlir/lib/Bindings/Python/DialectQuant.cpp     |  68 ++++++----
 mlir/lib/Bindings/Python/DialectSMT.cpp       |   6 +-
 .../Bindings/Python/DialectSparseTensor.cpp   |  32 ++---
 mlir/lib/Bindings/Python/DialectTransform.cpp |   4 +-
 mlir/lib/Bindings/Python/Rewrite.cpp          | 124 ++++++------------
 8 files changed, 163 insertions(+), 179 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 3ea8edec7b136..469fd524e8942 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -13,6 +13,8 @@
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
+#include "mlir/Bindings/Python/IRAttributes.h"
+
 namespace nb = nanobind;
 using namespace nanobind::literals;
 using namespace mlir::python::nanobind_adaptors;
@@ -54,9 +56,9 @@ struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](MlirAttribute target, uint32_t format, const nb::bytes &object,
-           std::optional<MlirAttribute> mlirObjectProps,
-           std::optional<MlirAttribute> mlirKernelsAttr,
+        [](const PyAttribute &target, uint32_t format, const nb::bytes &object,
+           std::optional<PyDictAttribute> mlirObjectProps,
+           std::optional<PyAttribute> mlirKernelsAttr,
            DefaultingPyMlirContext context) {
           MlirStringRef objectStrRef = mlirStringRefCreate(
               static_cast<char *>(const_cast<void *>(object.data())),
@@ -74,26 +76,30 @@ struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
         "kernels"_a = nb::none(), "context"_a = nb::none(),
         "Gets a gpu.object from parameters.");
 
-    c.def_prop_ro("target", [](MlirAttribute self) {
-      return mlirGPUObjectAttrGetTarget(self);
+    c.def_prop_ro("target", [](ObjectAttr &self) {
+      return PyAttribute(self.getContext(), mlirGPUObjectAttrGetTarget(self));
     });
-    c.def_prop_ro("format", [](MlirAttribute self) {
+    c.def_prop_ro("format", [](const ObjectAttr &self) {
       return mlirGPUObjectAttrGetFormat(self);
     });
-    c.def_prop_ro("object", [](MlirAttribute self) {
+    c.def_prop_ro("object", [](const ObjectAttr &self) {
       MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
       return nb::bytes(stringRef.data, stringRef.length);
     });
-    c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
-      if (mlirGPUObjectAttrHasProperties(self))
-        return nb::cast(mlirGPUObjectAttrGetProperties(self));
-      return nb::none();
-    });
-    c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
-      if (mlirGPUObjectAttrHasKernels(self))
-        return nb::cast(mlirGPUObjectAttrGetKernels(self));
-      return nb::none();
-    });
+    c.def_prop_ro(
+        "properties", [](ObjectAttr &self) -> std::optional<PyDictAttribute> {
+          if (mlirGPUObjectAttrHasProperties(self))
+            return PyDictAttribute(self.getContext(),
+                                   mlirGPUObjectAttrGetProperties(self));
+          return std::nullopt;
+        });
+    c.def_prop_ro("kernels",
+                  [](ObjectAttr &self) -> std::optional<PyAttribute> {
+                    if (mlirGPUObjectAttrHasKernels(self))
+                      return PyAttribute(self.getContext(),
+                                         mlirGPUObjectAttrGetKernels(self));
+                    return std::nullopt;
+                  });
   }
 };
 } // namespace gpu
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index d4eb078c0f55c..ff31398225a9c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -42,13 +42,16 @@ struct StructType : PyConcreteType<StructType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get_literal",
-        [](const std::vector<MlirType> &elements, bool packed, MlirLocation loc,
+        [](const std::vector<PyType> &elements, bool packed, MlirLocation loc,
            DefaultingPyMlirContext context) {
           python::CollectDiagnosticsToStringScope scope(
               mlirLocationGetContext(loc));
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
 
           MlirType type = mlirLLVMStructTypeLiteralGetChecked(
-              loc, elements.size(), elements.data(), packed);
+              loc, elements.size(), elements_.data(), packed);
           if (mlirTypeIsNull(type)) {
             throw nb::value_error(scope.takeMessage().c_str());
           }
@@ -59,12 +62,16 @@ struct StructType : PyConcreteType<StructType> {
 
     c.def_static(
         "get_literal_unchecked",
-        [](const std::vector<MlirType> &elements, bool packed,
+        [](const std::vector<PyType> &elements, bool packed,
            DefaultingPyMlirContext context) {
           python::CollectDiagnosticsToStringScope scope(context.get()->get());
 
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
+
           MlirType type = mlirLLVMStructTypeLiteralGet(
-              context.get()->get(), elements.size(), elements.data(), packed);
+              context.get()->get(), elements.size(), elements_.data(), packed);
           if (mlirTypeIsNull(type)) {
             throw nb::value_error(scope.takeMessage().c_str());
           }
@@ -95,9 +102,13 @@ struct StructType : PyConcreteType<StructType> {
 
     c.def(
         "set_body",
-        [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+        [](const StructType &self, const std::vector<PyType> &elements,
+           bool packed) {
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
           MlirLogicalResult result = mlirLLVMStructTypeSetBody(
-              self, elements.size(), elements.data(), packed);
+              self, elements.size(), elements_.data(), packed);
           if (!mlirLogicalResultIsSuccess(result)) {
             throw nb::value_error(
                 "Struct body already set to different content.");
@@ -107,26 +118,30 @@ struct StructType : PyConcreteType<StructType> {
 
     c.def_static(
         "new_identified",
-        [](const std::string &name, const std::vector<MlirType> &elements,
+        [](const std::string &name, const std::vector<PyType> &elements,
            bool packed, DefaultingPyMlirContext context) {
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
           return StructType(context->getRef(),
                             mlirLLVMStructTypeIdentifiedNewGet(
                                 context.get()->get(),
                                 mlirStringRefCreate(name.data(), name.length()),
-                                elements.size(), elements.data(), packed));
+                                elements.size(), elements_.data(), packed));
         },
         "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
         "context"_a = nb::none());
 
-    c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
-      if (mlirLLVMStructTypeIsLiteral(type))
-        return std::nullopt;
+    c.def_prop_ro(
+        "name", [](const StructType &type) -> std::optional<std::string> {
+          if (mlirLLVMStructTypeIsLiteral(type))
+            return std::nullopt;
 
-      MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
-      return StringRef(stringRef.data, stringRef.length).str();
-    });
+          MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+          return StringRef(stringRef.data, stringRef.length).str();
+        });
 
-    c.def_prop_ro("body", [](PyType type) -> nb::object {
+    c.def_prop_ro("body", [](const StructType &type) -> nb::object {
       // Don't crash in absence of a body.
       if (mlirLLVMStructTypeIsOpaque(type))
         return nb::none();
@@ -139,11 +154,13 @@ struct StructType : PyConcreteType<StructType> {
       return body;
     });
 
-    c.def_prop_ro("packed",
-                  [](PyType type) { return mlirLLVMStructTypeIsPacked(type); });
+    c.def_prop_ro("packed", [](const StructType &type) {
+      return mlirLLVMStructTypeIsPacked(type);
+    });
 
-    c.def_prop_ro("opaque",
-                  [](PyType type) { return mlirLLVMStructTypeIsOpaque(type); });
+    c.def_prop_ro("opaque", [](const StructType &type) {
+      return mlirLLVMStructTypeIsOpaque(type);
+    });
   }
 };
 
@@ -174,7 +191,7 @@ struct PointerType : PyConcreteType<PointerType> {
         },
         "address_space"_a = nb::none(), nb::kw_only(),
         "context"_a = nb::none());
-    c.def_prop_ro("address_space", [](PyType type) {
+    c.def_prop_ro("address_space", [](const PointerType &type) {
       return mlirLLVMPointerTypeGetAddressSpace(type);
     });
   }
@@ -186,12 +203,9 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
 
   m.def(
       "translate_module_to_llvmir",
-      [](MlirOperation module) {
-        return mlirTranslateModuleToLLVMIRToString(module);
+      [](PyOperationBase &module) { // Also accepts OpViews such as ModuleOp.
+        return mlirTranslateModuleToLLVMIRToString(module.getOperation());
       },
-      // clang-format off
-      nb::sig("def translate_module_to_llvmir(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> str"),
-      // clang-format on
       "module"_a, nb::rv_policy::take_ownership);
 }
 } // namespace llvm
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index d2ed3b141d724..5bb51eb63ce56 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -87,7 +87,7 @@ struct RangeType : PyConcreteType<RangeType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyType &elementType, DefaultingPyMlirContext context) {
+        [](const PyType &elementType, DefaultingPyMlirContext context) {
           return RangeType(context->getRef(), mlirPDLRangeTypeGet(elementType));
         },
         "Gets an instance of RangeType in the same context as the provided "
@@ -95,12 +95,10 @@ struct RangeType : PyConcreteType<RangeType> {
         nb::arg("element_type"), nb::arg("context").none() = nb::none());
     c.def_prop_ro(
         "element_type",
-        [](PyType &type) {
+        [](RangeType &type) {
           return PyType(type.getContext(),
                         mlirPDLRangeTypeGetElementType(type));
         },
-        nb::sig(
-            "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
         "Get the element type.");
   }
 };
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index a1e0a281a708d..3a9b8ffdf8971 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -14,6 +14,8 @@
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
+#include <mlir/Bindings/Python/IRAttributes.h>
+
 namespace nb = nanobind;
 using namespace llvm;
 using namespace mlir::python::nanobind_adaptors;
@@ -54,48 +56,52 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
         nb::arg("is_signed"), nb::arg("integral_width"));
     c.def_prop_ro(
         "expressed_type",
-        [](PyType type) {
+        [](QuantizedType &type) {
           return PyType(type.getContext(),
                         mlirQuantizedTypeGetExpressedType(type));
         },
         "Type expressed by this quantized type.");
     c.def_prop_ro(
         "flags",
-        [](const PyType &type) { return mlirQuantizedTypeGetFlags(type); },
+        [](const QuantizedType &type) {
+          return mlirQuantizedTypeGetFlags(type);
+        },
         "Flags of this quantized type (named accessors should be preferred to "
         "this)");
     c.def_prop_ro(
         "is_signed",
-        [](const PyType &type) { return mlirQuantizedTypeIsSigned(type); },
+        [](const QuantizedType &type) {
+          return mlirQuantizedTypeIsSigned(type);
+        },
         "Signedness of this quantized type.");
     c.def_prop_ro(
         "storage_type",
-        [](PyType type) {
+        [](QuantizedType &type) {
           return PyType(type.getContext(),
                         mlirQuantizedTypeGetStorageType(type));
         },
         "Storage type backing this quantized type.");
     c.def_prop_ro(
         "storage_type_min",
-        [](const PyType &type) {
+        [](const QuantizedType &type) {
           return mlirQuantizedTypeGetStorageTypeMin(type);
         },
         "The minimum value held by the storage type of this quantized type.");
     c.def_prop_ro(
         "storage_type_max",
-        [](const PyType &type) {
+        [](const QuantizedType &type) {
           return mlirQuantizedTypeGetStorageTypeMax(type);
         },
         "The maximum value held by the storage type of this quantized type.");
     c.def_prop_ro(
         "storage_type_integral_width",
-        [](const PyType &type) {
+        [](const QuantizedType &type) {
           return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
         },
         "The bitwidth of the storage type of this quantized type.");
     c.def(
         "is_compatible_expressed_type",
-        [](const PyType &type, const PyType &candidate) {
+        [](const QuantizedType &type, const PyType &candidate) {
           return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
         },
         "Checks whether the candidate type can be expressed by this quantized "
@@ -103,14 +109,14 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
         nb::arg("candidate"));
     c.def_prop_ro(
         "quantized_element_type",
-        [](PyType type) {
+        [](QuantizedType &type) {
           return PyType(type.getContext(),
                         mlirQuantizedTypeGetQuantizedElementType(type));
         },
         "Element type of this quantized type expressed as quantized type.");
     c.def(
         "cast_from_storage_type",
-        [](PyType type, const PyType &candidate) {
+        [](QuantizedType &type, const PyType &candidate) {
           MlirType castResult =
               mlirQuantizedTypeCastFromStorageType(type, candidate);
           if (!mlirTypeIsNull(castResult))
@@ -125,10 +131,10 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
         nb::arg("candidate"));
     c.def_static(
         "cast_to_storage_type",
-        [](const PyType &type) {
+        [](PyType &type) { // Accepts any type based on a quantized type.
           MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
           if (!mlirTypeIsNull(castResult))
-            return castResult;
+            return PyType(type.getContext(), castResult);
           throw nb::type_error("Invalid cast.");
         },
         "Casts from a type based on a quantized type to a corresponding type "
@@ -137,7 +143,7 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
         nb::arg("type"));
     c.def(
         "cast_from_expressed_type",
-        [](PyType type, const PyType &candidate) {
+        [](QuantizedType &type, const PyType &candidate) {
           MlirType castResult =
               mlirQuantizedTypeCastFromExpressedType(type, candidate);
           if (!mlirTypeIsNull(castResult))
@@ -151,10 +157,10 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
         nb::arg("candidate"));
     c.def_static(
         "cast_to_expressed_type",
-        [](const PyType &type) {
+        [](PyType &type) { // Accepts any type based on a quantized type.
           MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
           if (!mlirTypeIsNull(castResult))
-            return castResult;
+            return PyType(type.getContext(), castResult);
           throw nb::type_error("Invalid cast.");
         },
         "Casts from a type based on a quantized type to a corresponding type "
@@ -164,7 +170,7 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
         nb::arg("type"));
     c.def(
         "cast_expressed_to_storage_type",
-        [](PyType type, const PyType &candidate) {
+        [](QuantizedType &type, const PyType &candidate) {
           MlirType castResult =
               mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
           if (!mlirTypeIsNull(castResult))
@@ -238,21 +244,21 @@ struct UniformQuantizedType
         nb::arg("storage_type_max"), nb::arg("context") = nb::none());
     c.def_prop_ro(
         "scale",
-        [](const PyType &type) {
+        [](const UniformQuantizedType &type) {
           return mlirUniformQuantizedTypeGetScale(type);
         },
         "The scale designates the difference between the real values "
         "corresponding to consecutive quantized values differing by 1.");
     c.def_prop_ro(
         "zero_point",
-        [](const PyType &type) {
+        [](const UniformQuantizedType &type) {
           return mlirUniformQuantizedTypeGetZeroPoint(type);
         },
         "The storage value corresponding to the real value 0 in the affine "
         "equation.");
     c.def_prop_ro(
         "is_fixed_point",
-        [](const PyType &type) {
+        [](const UniformQuantizedType &type) {
           return mlirUniformQuantizedTypeIsFixedPoint(type);
         },
         "Fixed point values are real numbers divided by a scale.");
@@ -298,7 +304,7 @@ struct UniformQuantizedPerAxisType
         nb::arg("storage_type_max"), nb::arg("context") = nb::none());
     c.def_prop_ro(
         "scales",
-        [](const PyType &type) {
+        [](const UniformQuantizedPerAxisType &type) {
           intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
           std::vector<double> scales;
           scales.reserve(nDim);
@@ -313,7 +319,7 @@ struct UniformQuantizedPerAxisType
         "scale corresponds to the ith slice in the quantized_dimension.");
     c.def_prop_ro(
         "zero_points",
-        [](const PyType &type) {
+        [](const UniformQuantizedPerAxisType &type) {
           intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
           std::vector<int64_t> zeroPoints;
           zeroPoints.reserve(nDim);
@@ -329,14 +335,14 @@ struct UniformQuantizedPerAxisType
         "quantized_dimension.");
     c.def_prop_ro(
         "quantized_dimension",
-        [](const PyType &type) {
+        [](const UniformQuantizedPerAxisType &type) {
           return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
         },
         "Specifies the dimension of the shape that the scales and zero points "
         "correspond to.");
     c.def_prop_ro(
         "is_fixed_point",
-        [](const PyType &type) {
+        [](const UniformQuantizedPerAxisType &type) {
           return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
         },
         "Fixed point values are real numbers divided by a scale.");
@@ -379,7 +385,7 @@ struct UniformQuantizedSubChannelType
         nb::arg("context") = nb::none());
     c.def_prop_ro(
         "quantized_dimensions",
-        [](const PyType &type) {
+        [](const UniformQuantizedSubChannelType &type) {
           intptr_t nDim =
               mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
           std::vector<int32_t> quantizedDimensions;
@@ -399,7 +405,7 @@ struct UniformQuantizedSubChannelType
         "i-th block size from block_sizes method.");
     c.def_prop_ro(
         "block_sizes",
-        [](const PyType &type) {
+        [](const UniformQuantizedSubChannelType &type) {
           intptr_t nDim =
               mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
           std::vector<int64_t> blockSizes;
@@ -417,14 +423,18 @@ struct UniformQuantizedSubChannelType
         "in the list returned by quantized_dimensions method.");
     c.def_prop_ro(
         "scales",
-        [](const PyType &type) -> MlirAttribute {
-          return mlirUniformQuantizedSubChannelTypeGetScales(type);
+        [](UniformQuantizedSubChannelType &type) {
+          return PyDenseElementsAttribute(
+              type.getContext(),
+              mlirUniformQuantizedSubChannelTypeGetScales(type));
         },
         "The scales of the quantized type.");
     c.def_prop_ro(
         "zero_points",
-        [](const PyType &type) -> MlirAttribute {
-          return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+        [](UniformQuantizedSubChannelType &type) {
+          return PyDenseElementsAttribute(
+              type.getContext(),
+              mlirUniformQuantizedSubChannelTypeGetZeroPoints(type));
         },
         "The zero points of the quantized type.");
   }
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 39490155d5216..2c12341b81439 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -97,7 +97,7 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
 
   m.def(
       "export_smtlib",
-      [&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
-                      bool indentLetBody) {
+      [exportSMTLIB](const PyOperation &module, bool inlineSingleUseValues,
+                     bool indentLetBody) { // Copy: binding outlives the local.
         return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
       },
@@ -105,9 +105,9 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
       "indent_let_body"_a = false);
   m.def(
       "export_smtlib",
-      [&exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
-                      bool indentLetBody) {
+      [exportSMTLIB](PyModule &module, bool inlineSingleUseValues,
+                     bool indentLetBody) { // Copy: binding outlives the local.
-        return exportSMTLIB(mlirModuleGetOperation(module),
+        return exportSMTLIB(mlirModuleGetOperation(module.get()),
                             inlineSingleUseValues, indentLetBody);
       },
       "module"_a, "inline_single_use_values"_a = false,
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 6ec58dd88d24f..ca197ba32e074 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -52,10 +52,10 @@ struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
     c.def_static(
         "get",
         [](std::vector<MlirSparseTensorLevelType> lvlTypes,
-           std::optional<MlirAffineMap> dimToLvl,
-           std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
-           std::optional<MlirAttribute> explicitVal,
-           std::optional<MlirAttribute> implicitVal,
+           std::optional<PyAffineMap> dimToLvl,
+           std::optional<PyAffineMap> lvlToDim, int posWidth, int crdWidth,
+           std::optional<PyAttribute> explicitVal,
+           std::optional<PyAttribute> implicitVal,
            DefaultingPyMlirContext context) {
           return EncodingAttr(
               context->getRef(),
@@ -93,7 +93,7 @@ struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
         nb::arg("n") = 0, nb::arg("m") = 0,
         "Builds a sparse_tensor.encoding.level_type from parameters.");
 
-    c.def_prop_ro("lvl_types", [](MlirAttribute self) {
+    c.def_prop_ro("lvl_types", [](const EncodingAttr &self) {
       const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
       std::vector<MlirSparseTensorLevelType> ret;
       ret.reserve(lvlRank);
@@ -103,53 +103,53 @@ struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
     });
 
     c.def_prop_ro(
-        "dim_to_lvl", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+        "dim_to_lvl", [](EncodingAttr &self) -> std::optional<PyAffineMap> {
           MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
           if (mlirAffineMapIsNull(ret))
             return {};
-          return ret;
+          return PyAffineMap(self.getContext(), ret);
         });
 
     c.def_prop_ro(
-        "lvl_to_dim", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+        "lvl_to_dim", [](EncodingAttr &self) -> std::optional<PyAffineMap> {
           MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
           if (mlirAffineMapIsNull(ret))
             return {};
-          return ret;
+          return PyAffineMap(self.getContext(), ret);
         });
 
     c.def_prop_ro("pos_width", mlirSparseTensorEncodingAttrGetPosWidth);
     c.def_prop_ro("crd_width", mlirSparseTensorEncodingAttrGetCrdWidth);
 
     c.def_prop_ro(
-        "explicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+        "explicit_val", [](EncodingAttr &self) -> std::optional<PyAttribute> {
           MlirAttribute ret = mlirSparseTensorEncodingAttrGetExplicitVal(self);
           if (mlirAttributeIsNull(ret))
             return {};
-          return ret;
+          return PyAttribute(self.getContext(), ret);
         });
 
     c.def_prop_ro(
-        "implicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+        "implicit_val", [](EncodingAttr &self) -> std::optional<PyAttribute> {
           MlirAttribute ret = mlirSparseTensorEncodingAttrGetImplicitVal(self);
           if (mlirAttributeIsNull(ret))
             return {};
-          return ret;
+          return PyAttribute(self.getContext(), ret);
         });
 
-    c.def_prop_ro("structured_n", [](MlirAttribute self) -> unsigned {
+    c.def_prop_ro("structured_n", [](const EncodingAttr &self) -> unsigned {
       const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
       return mlirSparseTensorEncodingAttrGetStructuredN(
           mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
     });
 
-    c.def_prop_ro("structured_m", [](MlirAttribute self) -> unsigned {
+    c.def_prop_ro("structured_m", [](const EncodingAttr &self) -> unsigned {
       const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
       return mlirSparseTensorEncodingAttrGetStructuredM(
           mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
     });
 
-    c.def_prop_ro("lvl_formats_enum", [](MlirAttribute self) {
+    c.def_prop_ro("lvl_formats_enum", [](const EncodingAttr &self) {
       const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
       std::vector<PySparseTensorLevelFormat> ret;
       ret.reserve(lvlRank);
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index f42ebd004d09f..19e6418f067bb 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -118,7 +118,7 @@ struct OperationType : PyConcreteType<OperationType> {
         nb::arg("operation_name"), nb::arg("context").none() = nb::none());
     c.def_prop_ro(
         "operation_name",
-        [](const PyType &type) {
+        [](const OperationType &type) {
           MlirStringRef operationName =
               mlirTransformOperationTypeGetOperationName(type);
           return nb::str(operationName.data, operationName.length);
@@ -149,7 +149,7 @@ struct ParamType : PyConcreteType<ParamType> {
         nb::arg("type"), nb::arg("context").none() = nb::none());
     c.def_prop_ro(
         "type",
-        [](PyType type) {
+        [](ParamType type) {
           return PyType(type.getContext(), mlirTransformParamTypeGetType(type));
         },
         "Get the type this ParamType is associated with.");
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index c282f4b6996e5..f04b9b7788dd0 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -55,7 +55,7 @@ class PyPatternRewriter {
     mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
   }
 
-  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+  void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
 
 private:
   MlirRewriterBase base;
@@ -342,38 +342,30 @@ void populateRewriteSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
   // Mapping of the PatternRewriter
   //----------------------------------------------------------------------------
-  nb::
-      class_<PyPatternRewriter>(m, "PatternRewriter")
-          .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                       "The current insertion point of the PatternRewriter.")
-          .def(
-              "replace_op",
-              [](PyPatternRewriter &self, MlirOperation op,
-                 MlirOperation newOp) { self.replaceOp(op, newOp); },
-              "Replace an operation with a new operation.", nb::arg("op"),
-              nb::arg("new_op"),
-              // clang-format off
-              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
-              // clang-format on
-              )
-          .def(
-              "replace_op",
-              [](PyPatternRewriter &self, MlirOperation op,
-                 const std::vector<MlirValue> &values) {
-                self.replaceOp(op, values);
-              },
-              "Replace an operation with a list of values.", nb::arg("op"),
-              nb::arg("values"),
-              // clang-format off
-              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
-              // clang-format on
-              )
-          .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
-               nb::arg("op"),
-               // clang-format off
-                nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
-               // clang-format on
-          );
+  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
+      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+                   "The current insertion point of the PatternRewriter.")
+      .def(
+          "replace_op",
+          [](PyPatternRewriter &self, PyOperationBase &op,
+             PyOperationBase &newOp) {
+            self.replaceOp(op.getOperation(), newOp.getOperation());
+          },
+          "Replace an operation with a new operation.", nb::arg("op"),
+          nb::arg("new_op"))
+      .def(
+          "replace_op",
+          [](PyPatternRewriter &self, PyOperationBase &op,
+             const std::vector<PyValue> &values) {
+            std::vector<MlirValue> values_(values.size());
+            std::transform(values.begin(), values.end(), values_.begin(),
+                           [](const PyValue &val) { return val; });
+            self.replaceOp(op.getOperation(), values_);
+          },
+          "Replace an operation with a list of values.", nb::arg("op"),
+          nb::arg("values"))
+      .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+           nb::arg("op"));
 
   //----------------------------------------------------------------------------
   // Mapping of the RewritePatternSet
@@ -428,42 +420,21 @@ void populateRewriteSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
-      .def(
-          "append",
-          [](PyMlirPDLResultList results, const PyValue &value) {
-            mlirPDLResultListPushBackValue(results, value);
-          },
-          // clang-format off
-          nb::sig("def append(self, value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
-          // clang-format on
-          )
-      .def(
-          "append",
-          [](PyMlirPDLResultList results, const PyOperation &op) {
-            mlirPDLResultListPushBackOperation(results, op);
-          },
-          // clang-format off
-          nb::sig("def append(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
-          // clang-format on
-          )
-      .def(
-          "append",
-          [](PyMlirPDLResultList results, const PyType &type) {
-            mlirPDLResultListPushBackType(results, type);
-          },
-          // clang-format off
-          nb::sig("def append(self, type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
-          // clang-format on
-          )
-      .def(
-          "append",
-          [](PyMlirPDLResultList results, const PyAttribute &attr) {
-            mlirPDLResultListPushBackAttribute(results, attr);
-          },
-          // clang-format off
-          nb::sig("def append(self, attr: " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
-          // clang-format on
-      );
+      .def("append",
+           [](PyMlirPDLResultList results, const PyValue &value) {
+             mlirPDLResultListPushBackValue(results, value);
+           })
+      .def("append",
+           [](PyMlirPDLResultList results, const PyOperation &op) {
+             mlirPDLResultListPushBackOperation(results, op);
+           })
+      .def("append",
+           [](PyMlirPDLResultList results, const PyType &type) {
+             mlirPDLResultListPushBackType(results, type);
+           })
+      .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) {
+        mlirPDLResultListPushBackAttribute(results, attr);
+      });
   nb::class_<PyPDLPatternModule>(m, "PDLModule")
       .def(
           "__init__",
@@ -471,9 +442,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
             new (&self) PyPDLPatternModule(
                 mlirPDLPatternModuleFromModule(module.get()));
           },
-          // clang-format off
-          nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
-          // clang-format on
           "module"_a, "Create a PDL module from the given module.")
       .def(
           "__init__",
@@ -481,9 +449,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
             new (&self) PyPDLPatternModule(
                 mlirPDLPatternModuleFromModule(module.get()));
           },
-          // clang-format off
-          nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
-          // clang-format on
           "module"_a, "Create a PDL module from the given module.")
       .def(
           "freeze",
@@ -552,9 +517,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
            throw std::runtime_error("pattern application failed to converge");
        },
        "module"_a, "set"_a,
-       // clang-format off
-       nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
-       // clang-format on
        "Applys the given patterns to the given module greedily while folding "
        "results.")
       .def(
@@ -568,9 +530,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
                   "pattern application failed to converge");
           },
           "op"_a, "set"_a,
-          // clang-format off
-          nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
-          // clang-format on
           "Applys the given patterns to the given op greedily while folding "
           "results.")
       .def(
@@ -579,9 +538,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
             mlirWalkAndApplyPatterns(op.getOperation(), set.get());
           },
           "op"_a, "set"_a,
-          // clang-format off
-          nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
-          // clang-format on
           "Applies the given patterns to the given op by a fast walk-based "
           "driver.");
 }



More information about the llvm-branch-commits mailing list