[Mlir-commits] [mlir] 89a92fb - [mlir] Rework subclass construction in PybindAdaptors.h

Alex Zinenko llvmlistbot at llvm.org
Wed Jan 19 09:09:12 PST 2022


Author: Alex Zinenko
Date: 2022-01-19T18:09:05+01:00
New Revision: 89a92fb3ba668d0a4ec25c1268c31d9b35ab17e7

URL: https://github.com/llvm/llvm-project/commit/89a92fb3ba668d0a4ec25c1268c31d9b35ab17e7
DIFF: https://github.com/llvm/llvm-project/commit/89a92fb3ba668d0a4ec25c1268c31d9b35ab17e7.diff

LOG: [mlir] Rework subclass construction in PybindAdaptors.h

The constructor function was being defined without indicating its "__init__"
name, which made Pybind interpret it as a regular function rather than a
constructor. When overload resolution failed, Pybind would attempt to print the
arguments actually passed to the function, including "self", which is not
initialized since the constructor couldn't be called. This would result in
"__repr__" being called with "self" referencing an uninitialized MLIR C API
object, which in turn would cause undefined behavior when attempting to print
in C++. Even if the correct name is provided, the mechanism used by
PybindAdaptors.h to bind constructors directly as "__init__" functions taking
"self" is deprecated by Pybind. The new mechanism does not seem to have access
to a fully-constructed "self" object (i.e., the constructor in C++ takes a
`pybind11::detail::value_and_holder` that cannot be forwarded back to Python).

Instead, redefine "__new__" to perform the required checks (there is no
additional initialization needed for attributes and types as they are all
wrappers around a C++ pointer). "__new__" can call its equivalent on a
superclass without needing "self".

Bump pybind11 dependency to 2.8.0, which is the first version that allows one
to redefine "__new__".

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D117646

Added: 
    

Modified: 
    mlir/cmake/modules/MLIRDetectPythonEnv.cmake
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/python/mlir/dialects/python_test.py
    mlir/python/requirements.txt
    mlir/test/python/CMakeLists.txt
    mlir/test/python/dialects/python_test.py
    mlir/test/python/lib/PythonTestCAPI.cpp
    mlir/test/python/lib/PythonTestCAPI.h
    mlir/test/python/lib/PythonTestDialect.cpp
    mlir/test/python/lib/PythonTestDialect.h
    mlir/test/python/lib/PythonTestModule.cpp
    mlir/test/python/python_test_ops.td

Removed: 
    


################################################################################
diff  --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
index 4739fddd93a4b..9b36406edfcde 100644
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -32,7 +32,7 @@ macro(mlir_configure_python_dev_packages)
   message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
   message(STATUS "Found numpy v${Python3_NumPy_VERSION}: ${Python3_NumPy_INCLUDE_DIRS}")
   mlir_detect_pybind11_install()
-  find_package(pybind11 2.6 CONFIG REQUIRED)
+  find_package(pybind11 2.8 CONFIG REQUIRED)
   message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}")
   message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
                  "suffix = '${PYTHON_MODULE_SUFFIX}', "

diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 0340e9cc4b87f..73cc7e4412fbd 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -314,31 +314,34 @@ class mlir_attribute_subclass : public pure_subclass {
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
-                          IsAFunctionTy isaFunction,
-                          const py::object &superClass)
-      : pure_subclass(scope, typeClassName, superClass) {
-    // Casting constructor. Note that defining an __init__ method is special
-    // and not yet generalized on pure_subclass (it requires a somewhat
-    // different cpp_function and other requirements on chaining to super
-    // __init__ make it more awkward to do generally).
+                          IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to
+    // properly call chain to parent `__init__` in pybind11 due to its special
+    // handling for init functions that don't have a fully constructed
+    // self-reference, which makes it impossible to forward it to `__init__` of
+    // a superclass. Instead, provide a custom `__new__` and call that of a
+    // superclass, which eventually calls `__init__` of the superclass. Since
+    // attribute subclasses have no additional members, we can just return the
+    // instance thus created without amending it.
     std::string captureTypeName(
         typeClassName); // As string in case if typeClassName is not static.
-    py::cpp_function initCf(
-        [superClass, isaFunction, captureTypeName](py::object self,
-                                                   py::object otherType) {
-          MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherType);
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureTypeName](py::object cls,
+                                                 py::object otherAttribute) {
+          MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
           if (!isaFunction(rawAttribute)) {
-            auto origRepr = py::repr(otherType).cast<std::string>();
+            auto origRepr = py::repr(otherAttribute).cast<std::string>();
             throw std::invalid_argument(
                 (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
                  " (from " + origRepr + ")")
                     .str());
           }
-          superClass.attr("__init__")(self, otherType);
+          py::object self = superCls.attr("__new__")(cls, otherAttribute);
+          return self;
         },
-        py::arg("cast_from_type"), py::is_method(py::none()),
-        "Casts the passed type to this specific sub-type.");
-    thisClass.attr("__init__") = initCf;
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr"));
+    thisClass.attr("__new__") = newCf;
 
     // 'isinstance' method.
     def_staticmethod(
@@ -366,17 +369,21 @@ class mlir_type_subclass : public pure_subclass {
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_type_subclass(py::handle scope, const char *typeClassName,
-                     IsAFunctionTy isaFunction, const py::object &superClass)
-      : pure_subclass(scope, typeClassName, superClass) {
-    // Casting constructor. Note that defining an __init__ method is special
-    // and not yet generalized on pure_subclass (it requires a somewhat
-    // different cpp_function and other requirements on chaining to super
-    // __init__ make it more awkward to do generally).
+                     IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to
+    // properly call chain to parent `__init__` in pybind11 due to its special
+    // handling for init functions that don't have a fully constructed
+    // self-reference, which makes it impossible to forward it to `__init__` of
+    // a superclass. Instead, provide a custom `__new__` and call that of a
+    // superclass, which eventually calls `__init__` of the superclass. Since
+    // type subclasses have no additional members, we can just return the
+    // instance thus created without amending it.
     std::string captureTypeName(
         typeClassName); // As string in case if typeClassName is not static.
-    py::cpp_function initCf(
-        [superClass, isaFunction, captureTypeName](py::object self,
-                                                   py::object otherType) {
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureTypeName](py::object cls,
+                                                 py::object otherType) {
           MlirType rawType = py::cast<MlirType>(otherType);
           if (!isaFunction(rawType)) {
             auto origRepr = py::repr(otherType).cast<std::string>();
@@ -385,11 +392,11 @@ class mlir_type_subclass : public pure_subclass {
                                          origRepr + ")")
                                             .str());
           }
-          superClass.attr("__init__")(self, otherType);
+          py::object self = superCls.attr("__new__")(cls, otherType);
+          return self;
         },
-        py::arg("cast_from_type"), py::is_method(py::none()),
-        "Casts the passed type to this specific sub-type.");
-    thisClass.attr("__init__") = initCf;
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
+    thisClass.attr("__new__") = newCf;
 
     // 'isinstance' method.
     def_staticmethod(

diff  --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 82c01d5a091c7..9f560c205ef4e 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest

diff  --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index f76dcf6766ae0..0cc86af2c9cfb 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,3 @@
 numpy
-# Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136
-pybind11>=2.6.0,!=2.7.0
+pybind11>=2.8.0
 PyYAML

diff  --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
index c8cb474760e2c..1db957c86819d 100644
--- a/mlir/test/python/CMakeLists.txt
+++ b/mlir/test/python/CMakeLists.txt
@@ -3,6 +3,10 @@ mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
 mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
 mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
 mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(lib/PythonTestTypes.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRPythonTestIncGen)
 
 add_subdirectory(lib)

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index f9da91fba4cdf..9c096577f9360 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -225,3 +225,62 @@ def testOptionalOperandOp():
       op2 = test.OptionalOperandOp(op1)
       # CHECK: op2.input is None: False
       print(f"op2.input is None: {op2.input is None}")
+
+
+# CHECK-LABEL: TEST: testCustomAttribute
+@run
+def testCustomAttribute():
+  with Context() as ctx:
+    test.register_python_test_dialect(ctx)
+    a = test.TestAttr.get()
+    # CHECK: #python_test.test_attr
+    print(a)
+
+    # The following cast must not assert.
+    b = test.TestAttr(a)
+
+    unit = UnitAttr.get()
+    try:
+      test.TestAttr(unit)
+    except ValueError as e:
+      assert "Cannot cast attribute to TestAttr" in str(e)
+    else:
+      raise
+
+    # The following must trigger a TypeError from pybind (therefore, not
+    # checking its message) and must not crash.
+    try:
+      test.TestAttr(42, 56)
+    except TypeError:
+      pass
+    else:
+      raise
+
+
+@run
+def testCustomType():
+  with Context() as ctx:
+    test.register_python_test_dialect(ctx)
+    a = test.TestType.get()
+    # CHECK: !python_test.test_type
+    print(a)
+
+    # The following cast must not assert.
+    b = test.TestType(a)
+
+    i8 = IntegerType.get_signless(8)
+    try:
+      test.TestType(i8)
+    except ValueError as e:
+      assert "Cannot cast type to TestType" in str(e)
+    else:
+      raise
+
+    # The following must trigger a TypeError from pybind (therefore, not
+    # checking its message) and must not crash.
+    try:
+      test.TestType(42, 56)
+    except TypeError:
+      pass
+    else:
+      raise

diff  --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 474476e741985..e52588aa7dc11 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -9,6 +9,23 @@
 #include "PythonTestCAPI.h"
 #include "PythonTestDialect.h"
 #include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Wrap.h"
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
                                       python_test::PythonTestDialect)
+
+bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) {
+  return unwrap(attr).isa<python_test::TestAttrAttr>();
+}
+
+MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
+  return wrap(python_test::TestAttrAttr::get(unwrap(context)));
+}
+
+bool mlirTypeIsAPythonTestTestType(MlirType type) {
+  return unwrap(type).isa<python_test::TestTypeType>();
+}
+
+MlirType mlirPythonTestTestTypeGet(MlirContext context) {
+  return wrap(python_test::TestTypeType::get(unwrap(context)));
+}

diff  --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 627ce3fe9a151..dd491028dbec1 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -17,6 +17,16 @@ extern "C" {
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test);
 
+MLIR_CAPI_EXPORTED bool
+mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirPythonTestTestAttributeGet(MlirContext context);
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp
index b70c0336b2f64..a0ff31504c691 100644
--- a/mlir/test/python/lib/PythonTestDialect.cpp
+++ b/mlir/test/python/lib/PythonTestDialect.cpp
@@ -9,9 +9,16 @@
 #include "PythonTestDialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #include "PythonTestDialect.cpp.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "PythonTestAttributes.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "PythonTestTypes.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "PythonTestOps.cpp.inc"
 
@@ -21,5 +28,14 @@ void PythonTestDialect::initialize() {
 #define GET_OP_LIST
 #include "PythonTestOps.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "PythonTestAttributes.cpp.inc"
+      >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "PythonTestTypes.cpp.inc"
+      >();
 }
+
 } // namespace python_test

diff  --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index e25d00ceec980..e91cba106e56a 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -18,4 +18,10 @@
 #define GET_OP_CLASSES
 #include "PythonTestOps.h.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "PythonTestAttributes.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "PythonTestTypes.h.inc"
+
 #endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H

diff  --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index 4232a86518636..6fb9b24e69ae1 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
 namespace py = pybind11;
+using namespace mlir::python::adaptors;
 
 PYBIND11_MODULE(_mlirPythonTest, m) {
   m.def(
@@ -23,4 +24,20 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
         }
       },
       py::arg("context"), py::arg("load") = true);
+
+  mlir_attribute_subclass(m, "TestAttr",
+                          mlirAttributeIsAPythonTestTestAttribute)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirContext ctx) {
+            return cls(mlirPythonTestTestAttributeGet(ctx));
+          },
+          py::arg("cls"), py::arg("context") = py::none());
+  mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirContext ctx) {
+            return cls(mlirPythonTestTestTypeGet(ctx));
+          },
+          py::arg("cls"), py::arg("context") = py::none());
 }

diff  --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 6ee71dbf8b123..a274ffae69831 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -17,9 +17,36 @@ def Python_Test_Dialect : Dialect {
   let name = "python_test";
   let cppNamespace = "python_test";
 }
+
+class TestType<string name, string typeMnemonic>
+    : TypeDef<Python_Test_Dialect, name> {
+  let mnemonic = typeMnemonic;
+}
+
+class TestAttr<string name, string attrMnemonic>
+    : AttrDef<Python_Test_Dialect, name> {
+  let mnemonic = attrMnemonic;
+}
+
 class TestOp<string mnemonic, list<OpTrait> traits = []>
     : Op<Python_Test_Dialect, mnemonic, traits>;
 
+//===----------------------------------------------------------------------===//
+// Type definitions.
+//===----------------------------------------------------------------------===//
+
+def TestType : TestType<"TestType", "test_type">;
+
+//===----------------------------------------------------------------------===//
+// Attribute definitions.
+//===----------------------------------------------------------------------===//
+
+def TestAttr : TestAttr<"TestAttr", "test_attr">;
+
+//===----------------------------------------------------------------------===//
+// Operation definitions.
+//===----------------------------------------------------------------------===//
+
 def AttributedOp : TestOp<"attributed_op"> {
   let arguments = (ins I32Attr:$mandatory_i32,
                    OptionalAttr<I32Attr>:$optional_i32,


        


More information about the Mlir-commits mailing list