[Mlir-commits] [mlir] 69cc3cf - [MLIR][python bindings] implement `PyValue` subclassing to enable operator overloading

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 14 12:27:15 PDT 2023


Author: max
Date: 2023-04-14T14:25:06-05:00
New Revision: 69cc3cfb21b7b962ae77c2438bfb2d1e21f5f77e

URL: https://github.com/llvm/llvm-project/commit/69cc3cfb21b7b962ae77c2438bfb2d1e21f5f77e
DIFF: https://github.com/llvm/llvm-project/commit/69cc3cfb21b7b962ae77c2438bfb2d1e21f5f77e.diff

LOG: [MLIR][python bindings] implement `PyValue` subclassing to enable operator overloading

Differential Revision: https://reviews.llvm.org/D147758

Added: 
    

Modified: 
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/dialects/python_test.py
    mlir/test/python/dialects/python_test.py
    mlir/test/python/lib/PythonTestCAPI.cpp
    mlir/test/python/lib/PythonTestCAPI.h
    mlir/test/python/lib/PythonTestModule.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 98d80f0101527..bec3fc76e39d2 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -453,6 +453,62 @@ class mlir_type_subclass : public pure_subclass {
   }
 };
 
+/// Creates a custom subclass of mlir.ir.Value, implementing a casting
+/// constructor (`__new__`) and a static `isinstance` type-checking method.
+class mlir_value_subclass : public pure_subclass {
+public:
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  /// Subclasses by dynamically importing the mlir.ir.Value super-class.
+  mlir_value_subclass(py::handle scope, const char *valueClassName,
+                      IsAFunctionTy isaFunction)
+      : mlir_value_subclass(
+            scope, valueClassName, isaFunction,
+            py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
+  }
+
+  /// Subclasses with a provided mlir.ir.Value super-class. This must
+  /// be used if the subclass is being defined in the same extension module
+  /// as the mlir.ir class (otherwise, it will trigger a recursive
+  /// initialization).
+  mlir_value_subclass(py::handle scope, const char *valueClassName,
+                      IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, valueClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to
+    // properly call chain to parent `__init__` in pybind11 due to its special
+    // handling for init functions that don't have a fully constructed
+    // self-reference, which makes it impossible to forward it to `__init__` of
+    // a superclass. Instead, provide a custom `__new__` and call that of a
+    // superclass, which eventually calls `__init__` of the superclass. Since
+    // value subclasses have no additional members, we can just return the
+    // instance thus created without amending it.
+    std::string captureValueName(
+        valueClassName); // Owned copy in case valueClassName is not static.
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureValueName](py::object cls,
+                                                  py::object otherValue) {
+          MlirValue rawValue = py::cast<MlirValue>(otherValue);
+          if (!isaFunction(rawValue)) {
+            auto origRepr = py::repr(otherValue).cast<std::string>();
+            throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
+                                         captureValueName + " (from " +
+                                         origRepr + ")")
+                                            .str());
+          }
+          py::object self = superCls.attr("__new__")(cls, otherValue);
+          return self;
+        },
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
+    thisClass.attr("__new__") = newCf;
+
+    // Static 'isinstance' method: applies isaFunction to the given value.
+    def_staticmethod(
+        "isinstance",
+        [isaFunction](MlirValue other) { return isaFunction(other); },
+        py::arg("other_value"));
+  }
+};
+
 } // namespace adaptors
 } // namespace python
 } // namespace mlir

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f2d3780a4aa3c..f3fd386779373 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3260,6 +3260,7 @@ void mlir::python::populateIRCore(py::module &m) {
   // Mapping of Value.
   //----------------------------------------------------------------------------
   py::class_<PyValue>(m, "Value", py::module_local())
+      .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
       .def_property_readonly(

diff  --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 9f560c205ef4e..5d42ddc47a242 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index c73fce23d3c49..d826540bec1da 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -2,6 +2,7 @@
 
 from mlir.ir import *
 import mlir.dialects.python_test as test
+import mlir.dialects.tensor as tensor
 
 def run(f):
   print("\nTEST:", f.__name__)
@@ -302,3 +303,30 @@ def testCustomType():
       pass
     else:
       raise
+
+
+ at run
+# CHECK-LABEL: TEST: testTensorValue
+def testTensorValue():
+  with Context() as ctx, Location.unknown():
+    test.register_python_test_dialect(ctx)
+
+    i8 = IntegerType.get_signless(8)
+
+    class Tensor(test.TestTensorValue):
+      def __str__(self):
+        return super().__str__().replace("Value", "Tensor")
+
+    module = Module.create()
+    with InsertionPoint(module.body):
+      t = tensor.EmptyOp([10, 10], i8).result
+
+      # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+      print(Value(t))
+
+      tt = Tensor(t)
+      # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+      print(tt)
+
+      # CHECK: False
+      print(tt.is_null())

diff  --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index e52588aa7dc11..280cfa0b1738d 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -8,6 +8,7 @@
 
 #include "PythonTestCAPI.h"
 #include "PythonTestDialect.h"
+#include "mlir-c/BuiltinTypes.h"
 #include "mlir/CAPI/Registration.h"
 #include "mlir/CAPI/Wrap.h"
 
@@ -29,3 +30,10 @@ bool mlirTypeIsAPythonTestTestType(MlirType type) {
 MlirType mlirPythonTestTestTypeGet(MlirContext context) {
   return wrap(python_test::TestTypeType::get(unwrap(context)));
 }
+
+/// Returns whether the type of `value` satisfies `mlirTypeIsATensor`. Note
+/// that, despite the `TypeIsA` prefix, this predicate takes an MlirValue.
+bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
+  MlirType valueType = wrap(unwrap(value).getType());
+  return mlirTypeIsATensor(valueType);
+}

diff  --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index f72a3c124b3c0..90c5d4a383f95 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
 
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index 6fb9b24e69ae1..f17f0821599c5 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -40,4 +40,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
+  mlir_value_subclass(m, "TestTensorValue",
+                      mlirTypeIsAPythonTestTestTensorValue)
+      .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
 }


        


More information about the Mlir-commits mailing list