[Mlir-commits] [mlir] [mlir][python] auto attribute casting (PR #97786)

Maksim Levental llvmlistbot at llvm.org
Thu Jul 4 20:42:16 PDT 2024


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/97786

This PR implements automatic attribute casting for downstream attributes, just as we already have for downstream types.

Use case: https://github.com/openxla/shardy

cc @bartchr808 

>From 3e5d81020dcb035f8d6dca99569b95e626cbdf0d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 4 Jul 2024 22:40:14 -0500
Subject: [PATCH] [mlir][python] auto attribute casting

---
 .../mlir/Bindings/Python/PybindAdaptors.h     | 29 +++++++++++++++++--
 mlir/test/python/dialects/python_test.py      | 14 ++++++++-
 mlir/test/python/lib/PythonTestCAPI.cpp       |  4 +++
 mlir/test/python/lib/PythonTestCAPI.h         |  2 ++
 mlir/test/python/lib/PythonTestDialect.h      |  6 ++--
 mlir/test/python/lib/PythonTestModule.cpp     |  7 +++--
 mlir/test/python/python_test_ops.td           |  4 +++
 7 files changed, 56 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index ebf50109f72f2..67cc48277efcb 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -406,21 +406,25 @@ class pure_subclass {
 class mlir_attribute_subclass : public pure_subclass {
 public:
   using IsAFunctionTy = bool (*)(MlirAttribute);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
 
   /// Subclasses by looking up the super-class dynamically.
   mlir_attribute_subclass(py::handle scope, const char *attrClassName,
-                          IsAFunctionTy isaFunction)
+                          IsAFunctionTy isaFunction,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : mlir_attribute_subclass(
             scope, attrClassName, isaFunction,
             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Attribute")) {}
+                .attr("Attribute"),
+            getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
   /// be used if the subclass is being defined in the same extension module
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
-                          IsAFunctionTy isaFunction, const py::object &superCls)
+                          IsAFunctionTy isaFunction, const py::object &superCls,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : pure_subclass(scope, typeClassName, superCls) {
     // Casting constructor. Note that it hard, if not impossible, to properly
     // call chain to parent `__init__` in pybind11 due to its special handling
@@ -454,6 +458,25 @@ class mlir_attribute_subclass : public pure_subclass {
         "isinstance",
         [isaFunction](MlirAttribute other) { return isaFunction(other); },
         py::arg("other_attribute"));
+    def("__repr__", [superCls, captureTypeName](py::object self) {
+      return py::repr(superCls(self))
+          .attr("replace")(superCls.attr("__name__"), captureTypeName);
+    });
+    if (getTypeIDFunction) {
+      // 'get_static_typeid' method.
+      // This is modeled as a static method instead of a static property because
+      // `def_property_readonly_static` is not available in `pure_subclass` and
+      // we do not want to introduce the complexity that pybind uses to
+      // implement it.
+      def_staticmethod("get_static_typeid",
+                       [getTypeIDFunction]() { return getTypeIDFunction(); });
+      py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+          .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+              getTypeIDFunction())(pybind11::cpp_function(
+              [thisClass = thisClass](const py::object &mlirAttribute) {
+                return thisClass(mlirAttribute);
+              }));
+    }
   }
 };
 
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 70927b22d4749..a76f3f2b5e458 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -307,11 +307,23 @@ def testOptionalOperandOp():
 # CHECK-LABEL: TEST: testCustomAttribute
 @run
 def testCustomAttribute():
-    with Context() as ctx:
+    with Context() as ctx, Location.unknown():
         a = test.TestAttr.get()
         # CHECK: #python_test.test_attr
         print(a)
 
+        # CHECK: python_test.custom_attributed_op  {
+        # CHECK: #python_test.test_attr
+        # CHECK: }
+        op2 = test.CustomAttributedOp(a)
+        print(f"{op2}")
+
+        # CHECK: #python_test.test_attr
+        print(f"{op2.test_attr}")
+
+        # CHECK: TestAttr(#python_test.test_attr)
+        print(repr(op2.test_attr))
+
         # The following cast must not assert.
         b = test.TestAttr(a)
 
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 71778a97d83a4..cb7d7677714fe 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -23,6 +23,10 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
   return wrap(python_test::TestAttrAttr::get(unwrap(context)));
 }
 
+MlirTypeID mlirPythonTestTestAttributeGetTypeID(void) {
+  return wrap(python_test::TestAttrAttr::getTypeID());
+}
+
 bool mlirTypeIsAPythonTestTestType(MlirType type) {
   return llvm::isa<python_test::TestTypeType>(unwrap(type));
 }
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 5f1ed3a5b2ad6..43f8fdcbfae12 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -23,6 +23,8 @@ mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirPythonTestTestAttributeGet(MlirContext context);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index 044381fcd4728..889365e1136b4 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -16,13 +16,13 @@
 
 #include "PythonTestDialect.h.inc"
 
-#define GET_OP_CLASSES
-#include "PythonTestOps.h.inc"
-
 #define GET_ATTRDEF_CLASSES
 #include "PythonTestAttributes.h.inc"
 
 #define GET_TYPEDEF_CLASSES
 #include "PythonTestTypes.h.inc"
 
+#define GET_OP_CLASSES
+#include "PythonTestOps.h.inc"
+
 #endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f81b851f8759b..a4f538dcb5594 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -44,10 +44,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
       py::arg("registry"));
 
   mlir_attribute_subclass(m, "TestAttr",
-                          mlirAttributeIsAPythonTestTestAttribute)
+                          mlirAttributeIsAPythonTestTestAttribute,
+                          mlirPythonTestTestAttributeGetTypeID)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirContext ctx) {
+          [](const py::object &cls, MlirContext ctx) {
             return cls(mlirPythonTestTestAttributeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
@@ -56,7 +57,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
                      mlirPythonTestTestTypeGetTypeID)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirContext ctx) {
+          [](const py::object &cls, MlirContext ctx) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 95301985e3fde..5a82c00ae6080 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -58,6 +58,10 @@ def AttributedOp : TestOp<"attributed_op"> {
                    UnitAttr:$unit);
 }
 
+def CustomAttributedOp : TestOp<"custom_attributed_op"> {
+  let arguments = (ins TestAttr:$test_attr);
+}
+
 def AttributesOp : TestOp<"attributes_op"> {
   let arguments = (ins
                    AffineMapArrayAttr:$x_affinemaparr,



More information about the Mlir-commits mailing list