[Mlir-commits] [mlir] 9315645 - [mlir][python] auto attribute casting (#97786)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 5 08:43:55 PDT 2024


Author: Maksim Levental
Date: 2024-07-05T10:43:51-05:00
New Revision: 9315645834ea81cf9550364a4950f289e9706a26

URL: https://github.com/llvm/llvm-project/commit/9315645834ea81cf9550364a4950f289e9706a26
DIFF: https://github.com/llvm/llvm-project/commit/9315645834ea81cf9550364a4950f289e9706a26.diff

LOG: [mlir][python] auto attribute casting (#97786)

Added: 
    

Modified: 
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/test/python/dialects/python_test.py
    mlir/test/python/lib/PythonTestCAPI.cpp
    mlir/test/python/lib/PythonTestCAPI.h
    mlir/test/python/lib/PythonTestDialect.h
    mlir/test/python/lib/PythonTestModule.cpp
    mlir/test/python/python_test_ops.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index ebf50109f72f23..df4b9bf713592d 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -406,21 +406,25 @@ class pure_subclass {
 class mlir_attribute_subclass : public pure_subclass {
 public:
   using IsAFunctionTy = bool (*)(MlirAttribute);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
 
   /// Subclasses by looking up the super-class dynamically.
   mlir_attribute_subclass(py::handle scope, const char *attrClassName,
-                          IsAFunctionTy isaFunction)
+                          IsAFunctionTy isaFunction,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : mlir_attribute_subclass(
             scope, attrClassName, isaFunction,
             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Attribute")) {}
+                .attr("Attribute"),
+            getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
   /// be used if the subclass is being defined in the same extension module
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
-                          IsAFunctionTy isaFunction, const py::object &superCls)
+                          IsAFunctionTy isaFunction, const py::object &superCls,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : pure_subclass(scope, typeClassName, superCls) {
     // Casting constructor. Note that it hard, if not impossible, to properly
     // call chain to parent `__init__` in pybind11 due to its special handling
@@ -454,6 +458,20 @@ class mlir_attribute_subclass : public pure_subclass {
         "isinstance",
         [isaFunction](MlirAttribute other) { return isaFunction(other); },
         py::arg("other_attribute"));
+    def("__repr__", [superCls, captureTypeName](py::object self) {
+      return py::repr(superCls(self))
+          .attr("replace")(superCls.attr("__name__"), captureTypeName);
+    });
+    if (getTypeIDFunction) {
+      def_staticmethod("get_static_typeid",
+                       [getTypeIDFunction]() { return getTypeIDFunction(); });
+      py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+          .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+              getTypeIDFunction())(pybind11::cpp_function(
+              [thisClass = thisClass](const py::object &mlirAttribute) {
+                return thisClass(mlirAttribute);
+              }));
+    }
   }
 };
 

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 70927b22d4749c..a76f3f2b5e4583 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -307,11 +307,23 @@ def testOptionalOperandOp():
 # CHECK-LABEL: TEST: testCustomAttribute
 @run
 def testCustomAttribute():
-    with Context() as ctx:
+    with Context() as ctx, Location.unknown():
         a = test.TestAttr.get()
         # CHECK: #python_test.test_attr
         print(a)
 
+        # CHECK: python_test.custom_attributed_op  {
+        # CHECK: #python_test.test_attr
+        # CHECK: }
+        op2 = test.CustomAttributedOp(a)
+        print(f"{op2}")
+
+        # CHECK: #python_test.test_attr
+        print(f"{op2.test_attr}")
+
+        # CHECK: TestAttr(#python_test.test_attr)
+        print(repr(op2.test_attr))
+
         # The following cast must not assert.
         b = test.TestAttr(a)
 

diff  --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 71778a97d83a41..cb7d7677714fe6 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -23,6 +23,10 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
   return wrap(python_test::TestAttrAttr::get(unwrap(context)));
 }
 
+MlirTypeID mlirPythonTestTestAttributeGetTypeID(void) {
+  return wrap(python_test::TestAttrAttr::getTypeID());
+}
+
 bool mlirTypeIsAPythonTestTestType(MlirType type) {
   return llvm::isa<python_test::TestTypeType>(unwrap(type));
 }

diff  --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 5f1ed3a5b2ad66..43f8fdcbfae125 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -23,6 +23,8 @@ mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirPythonTestTestAttributeGet(MlirContext context);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);

diff  --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index 044381fcd4728d..889365e1136b4e 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -16,13 +16,13 @@
 
 #include "PythonTestDialect.h.inc"
 
-#define GET_OP_CLASSES
-#include "PythonTestOps.h.inc"
-
 #define GET_ATTRDEF_CLASSES
 #include "PythonTestAttributes.h.inc"
 
 #define GET_TYPEDEF_CLASSES
 #include "PythonTestTypes.h.inc"
 
+#define GET_OP_CLASSES
+#include "PythonTestOps.h.inc"
+
 #endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H

diff  --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f81b851f8759bf..a4f538dcb55944 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -44,10 +44,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
       py::arg("registry"));
 
   mlir_attribute_subclass(m, "TestAttr",
-                          mlirAttributeIsAPythonTestTestAttribute)
+                          mlirAttributeIsAPythonTestTestAttribute,
+                          mlirPythonTestTestAttributeGetTypeID)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirContext ctx) {
+          [](const py::object &cls, MlirContext ctx) {
             return cls(mlirPythonTestTestAttributeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
@@ -56,7 +57,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
                      mlirPythonTestTestTypeGetTypeID)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirContext ctx) {
+          [](const py::object &cls, MlirContext ctx) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());

diff  --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index c0bc18448610a0..6211fb9987c76a 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -58,6 +58,10 @@ def AttributedOp : TestOp<"attributed_op"> {
                    UnitAttr:$unit);
 }
 
+def CustomAttributedOp : TestOp<"custom_attributed_op"> {
+  let arguments = (ins TestAttr:$test_attr);
+}
+
 def AttributesOp : TestOp<"attributes_op"> {
   let arguments = (ins
                    AffineMapArrayAttr:$x_affinemaparr,


        


More information about the Mlir-commits mailing list