[Mlir-commits] [mlir] 69cc3cf - [MLIR][python bindings] implement `PyValue` subclassing to enable operator overloading
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 14 12:27:15 PDT 2023
Author: max
Date: 2023-04-14T14:25:06-05:00
New Revision: 69cc3cfb21b7b962ae77c2438bfb2d1e21f5f77e
URL: https://github.com/llvm/llvm-project/commit/69cc3cfb21b7b962ae77c2438bfb2d1e21f5f77e
DIFF: https://github.com/llvm/llvm-project/commit/69cc3cfb21b7b962ae77c2438bfb2d1e21f5f77e.diff
LOG: [MLIR][python bindings] implement `PyValue` subclassing to enable operator overloading
Differential Revision: https://reviews.llvm.org/D147758
Added:
Modified:
mlir/include/mlir/Bindings/Python/PybindAdaptors.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/dialects/python_test.py
mlir/test/python/dialects/python_test.py
mlir/test/python/lib/PythonTestCAPI.cpp
mlir/test/python/lib/PythonTestCAPI.h
mlir/test/python/lib/PythonTestModule.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 98d80f0101527..bec3fc76e39d2 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -453,6 +453,62 @@ class mlir_type_subclass : public pure_subclass {
}
};
+/// Creates a custom subclass of mlir.ir.Value, implementing a casting
+/// constructor and type checking methods.
+class mlir_value_subclass : public pure_subclass {
+public:
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ /// Subclasses by looking up the super-class dynamically.
+ mlir_value_subclass(py::handle scope, const char *valueClassName,
+ IsAFunctionTy isaFunction)
+ : mlir_value_subclass(
+ scope, valueClassName, isaFunction,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
+ }
+
+ /// Subclasses with a provided mlir.ir.Value super-class. This must
+ /// be used if the subclass is being defined in the same extension module
+ /// as the mlir.ir class (otherwise, it will trigger a recursive
+ /// initialization).
+ mlir_value_subclass(py::handle scope, const char *valueClassName,
+ IsAFunctionTy isaFunction, const py::object &superCls)
+ : pure_subclass(scope, valueClassName, superCls) {
+ // Casting constructor. Note that it is hard, if not impossible, to
+ // properly call chain to parent `__init__` in pybind11 due to its special
+ // handling for init functions that don't have a fully constructed
+ // self-reference, which makes it impossible to forward it to `__init__` of
+ // a superclass. Instead, provide a custom `__new__` and call that of a
+ // superclass, which eventually calls `__init__` of the superclass. Since
+ // value subclasses have no additional members, we can just return the
+ // instance thus created without amending it.
+ std::string captureValueName(
+ valueClassName); // As string in case valueClassName is not static.
+ py::cpp_function newCf(
+ [superCls, isaFunction, captureValueName](py::object cls,
+ py::object otherValue) {
+ MlirValue rawValue = py::cast<MlirValue>(otherValue);
+ if (!isaFunction(rawValue)) {
+ auto origRepr = py::repr(otherValue).cast<std::string>();
+ throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
+ captureValueName + " (from " +
+ origRepr + ")")
+ .str());
+ }
+ py::object self = superCls.attr("__new__")(cls, otherValue);
+ return self;
+ },
+ py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
+ thisClass.attr("__new__") = newCf;
+
+ // 'isinstance' method.
+ def_staticmethod(
+ "isinstance",
+ [isaFunction](MlirValue other) { return isaFunction(other); },
+ py::arg("other_value"));
+ }
+};
+
} // namespace adaptors
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f2d3780a4aa3c..f3fd386779373 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3260,6 +3260,7 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of Value.
//----------------------------------------------------------------------------
py::class_<PyValue>(m, "Value", py::module_local())
+ .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
.def_property_readonly(
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 9f560c205ef4e..5d42ddc47a242 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index c73fce23d3c49..d826540bec1da 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -2,6 +2,7 @@
from mlir.ir import *
import mlir.dialects.python_test as test
+import mlir.dialects.tensor as tensor
def run(f):
print("\nTEST:", f.__name__)
@@ -302,3 +303,30 @@ def testCustomType():
pass
else:
raise
+
+
+@run
+# CHECK-LABEL: TEST: testTensorValue
+def testTensorValue():
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
+
+ i8 = IntegerType.get_signless(8)
+
+ class Tensor(test.TestTensorValue):
+ def __str__(self):
+ return super().__str__().replace("Value", "Tensor")
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+ t = tensor.EmptyOp([10, 10], i8).result
+
+ # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+ print(Value(t))
+
+ tt = Tensor(t)
+ # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+ print(tt)
+
+ # CHECK: False
+ print(tt.is_null())
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index e52588aa7dc11..280cfa0b1738d 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -8,6 +8,7 @@
#include "PythonTestCAPI.h"
#include "PythonTestDialect.h"
+#include "mlir-c/BuiltinTypes.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Wrap.h"
@@ -29,3 +30,7 @@ bool mlirTypeIsAPythonTestTestType(MlirType type) {
MlirType mlirPythonTestTestTypeGet(MlirContext context) {
return wrap(python_test::TestTypeType::get(unwrap(context)));
}
+
+bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
+ return mlirTypeIsATensor(wrap(unwrap(value).getType()));
+}
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index f72a3c124b3c0..90c5d4a383f95 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index 6fb9b24e69ae1..f17f0821599c5 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -40,4 +40,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
+ mlir_value_subclass(m, "TestTensorValue",
+ mlirTypeIsAPythonTestTestTensorValue)
+ .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
}
More information about the Mlir-commits
mailing list