[Mlir-commits] [mlir] [mlir][Python] port in-tree dialect extensions to use MLIRPythonSupport (PR #174156)

Maksim Levental llvmlistbot at llvm.org
Mon Jan 5 09:59:15 PST 2026


================
@@ -13,163 +13,208 @@
 #include "mlir-c/Support.h"
 #include "mlir-c/Target/LLVMIR.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 
 using namespace nanobind::literals;
-
 using namespace llvm;
 using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
-  //===--------------------------------------------------------------------===//
-  // StructType
-  //===--------------------------------------------------------------------===//
-
-  auto llvmStructType = mlir_type_subclass(
-      m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
-  llvmStructType
-      .def_classmethod(
-          "get_literal",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirLocation loc) {
-            CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
-            MlirType type = mlirLLVMStructTypeLiteralGetChecked(
-                loc, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "loc"_a = nb::none())
-      .def_classmethod(
-          "get_literal_unchecked",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirContext context) {
-            CollectDiagnosticsToStringScope scope(context);
-
-            MlirType type = mlirLLVMStructTypeLiteralGet(
-                context, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "context"_a = nb::none());
-
-  llvmStructType.def_classmethod(
-      "get_identified",
-      [](const nb::object &cls, const std::string &name, MlirContext context) {
-        return cls(mlirLLVMStructTypeIdentifiedGet(
-            context, mlirStringRefCreate(name.data(), name.size())));
-      },
-      "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm {
+//===--------------------------------------------------------------------===//
+// StructType
+//===--------------------------------------------------------------------===//
+
+struct StructType : PyConcreteType<StructType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMStructTypeGetTypeID;
+  static constexpr const char *pyClassName = "StructType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_literal",
+        [](const std::vector<PyType> &elements, bool packed, MlirLocation loc,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(
+              mlirLocationGetContext(loc));
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
+
+          MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+              loc, elements.size(), elements_.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none(),
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_literal_unchecked",
+        [](const std::vector<PyType> &elements, bool packed,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
+
+          MlirType type = mlirLLVMStructTypeLiteralGet(
+              context.get()->get(), elements.size(), elements_.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_identified",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+    c.def_static(
+        "get_opaque",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeOpaqueGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, "context"_a = nb::none());
+
+    c.def(
+        "set_body",
+        [](const StructType &self, const std::vector<PyType> &elements,
+           bool packed) {
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
+          MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+              self, elements.size(), elements_.data(), packed);
+          if (!mlirLogicalResultIsSuccess(result)) {
+            throw nb::value_error(
+                "Struct body already set to different content.");
+          }
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false);
+
+    c.def_static(
+        "new_identified",
+        [](const std::string &name, const std::vector<PyType> &elements,
+           bool packed, DefaultingPyMlirContext context) {
+          std::vector<MlirType> elements_(elements.size());
+          std::transform(elements.begin(), elements.end(), elements_.begin(),
+                         [](const PyType &elem) { return elem; });
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedNewGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.length()),
+                                elements.size(), elements_.data(), packed));
+        },
+        "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_prop_ro(
+        "name", [](const StructType &type) -> std::optional<std::string> {
+          if (mlirLLVMStructTypeIsLiteral(type))
+            return std::nullopt;
+
+          MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+          return StringRef(stringRef.data, stringRef.length).str();
+        });
+
+    c.def_prop_ro("body", [](const StructType &type) -> nb::object {
+      // Don't crash in absence of a body.
+      if (mlirLLVMStructTypeIsOpaque(type))
+        return nb::none();
+
+      nb::list body;
+      for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+           i < e; ++i) {
+        body.append(mlirLLVMStructTypeGetElementType(type, i));
+      }
+      return body;
+    });
+
+    c.def_prop_ro("packed", [](const StructType &type) {
+      return mlirLLVMStructTypeIsPacked(type);
+    });
+
+    c.def_prop_ro("opaque", [](const StructType &type) {
+      return mlirLLVMStructTypeIsOpaque(type);
+    });
+  }
+};
+
+//===--------------------------------------------------------------------===//
+// PointerType
+//===--------------------------------------------------------------------===//
+
+struct PointerType : PyConcreteType<PointerType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMPointerType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMPointerTypeGetTypeID;
+  static constexpr const char *pyClassName = "PointerType";
+  using PyConcreteType::PyConcreteType;
----------------
makslevental wrote:

done

https://github.com/llvm/llvm-project/pull/174156


More information about the Mlir-commits mailing list