[Mlir-commits] [mlir] [mlir][python] auto attribute casting (PR #97786)
Maksim Levental
llvmlistbot at llvm.org
Thu Jul 4 20:42:16 PDT 2024
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/97786
This PR implements automatic attribute casting for downstream attributes, just as we already have for downstream types.
Use case: https://github.com/openxla/shardy
cc @bartchr808
From 3e5d81020dcb035f8d6dca99569b95e626cbdf0d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 4 Jul 2024 22:40:14 -0500
Subject: [PATCH] [mlir][python] auto attribute casting
---
.../mlir/Bindings/Python/PybindAdaptors.h | 29 +++++++++++++++++--
mlir/test/python/dialects/python_test.py | 14 ++++++++-
mlir/test/python/lib/PythonTestCAPI.cpp | 4 +++
mlir/test/python/lib/PythonTestCAPI.h | 2 ++
mlir/test/python/lib/PythonTestDialect.h | 6 ++--
mlir/test/python/lib/PythonTestModule.cpp | 7 +++--
mlir/test/python/python_test_ops.td | 4 +++
7 files changed, 56 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index ebf50109f72f2..67cc48277efcb 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -406,21 +406,25 @@ class pure_subclass {
class mlir_attribute_subclass : public pure_subclass {
public:
using IsAFunctionTy = bool (*)(MlirAttribute);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
/// Subclasses by looking up the super-class dynamically.
mlir_attribute_subclass(py::handle scope, const char *attrClassName,
- IsAFunctionTy isaFunction)
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: mlir_attribute_subclass(
scope, attrClassName, isaFunction,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Attribute")) {}
+ .attr("Attribute"),
+ getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Attribute super-class. This must
/// be used if the subclass is being defined in the same extension module
/// as the mlir.ir class (otherwise, it will trigger a recursive
/// initialization).
mlir_attribute_subclass(py::handle scope, const char *typeClassName,
- IsAFunctionTy isaFunction, const py::object &superCls)
+ IsAFunctionTy isaFunction, const py::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: pure_subclass(scope, typeClassName, superCls) {
// Casting constructor. Note that it hard, if not impossible, to properly
// call chain to parent `__init__` in pybind11 due to its special handling
@@ -454,6 +458,25 @@ class mlir_attribute_subclass : public pure_subclass {
"isinstance",
[isaFunction](MlirAttribute other) { return isaFunction(other); },
py::arg("other_attribute"));
+ def("__repr__", [superCls, captureTypeName](py::object self) {
+ return py::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ // 'get_static_typeid' method.
+ // This is modeled as a static method instead of a static property because
+ // `def_property_readonly_static` is not available in `pure_subclass` and
+ // we do not want to introduce the complexity that pybind uses to
+ // implement it.
+ def_staticmethod("get_static_typeid",
+ [getTypeIDFunction]() { return getTypeIDFunction(); });
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction())(pybind11::cpp_function(
+ [thisClass = thisClass](const py::object &mlirAttribute) {
+ return thisClass(mlirAttribute);
+ }));
+ }
}
};
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 70927b22d4749..a76f3f2b5e458 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -307,11 +307,23 @@ def testOptionalOperandOp():
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
- with Context() as ctx:
+ with Context() as ctx, Location.unknown():
a = test.TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
+ # CHECK: python_test.custom_attributed_op {
+ # CHECK: #python_test.test_attr
+ # CHECK: }
+ op2 = test.CustomAttributedOp(a)
+ print(f"{op2}")
+
+ # CHECK: #python_test.test_attr
+ print(f"{op2.test_attr}")
+
+ # CHECK: TestAttr(#python_test.test_attr)
+ print(repr(op2.test_attr))
+
# The following cast must not assert.
b = test.TestAttr(a)
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 71778a97d83a4..cb7d7677714fe 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -23,6 +23,10 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
return wrap(python_test::TestAttrAttr::get(unwrap(context)));
}
+MlirTypeID mlirPythonTestTestAttributeGetTypeID(void) {
+ return wrap(python_test::TestAttrAttr::getTypeID());
+}
+
bool mlirTypeIsAPythonTestTestType(MlirType type) {
return llvm::isa<python_test::TestTypeType>(unwrap(type));
}
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 5f1ed3a5b2ad6..43f8fdcbfae12 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -23,6 +23,8 @@ mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute
mlirPythonTestTestAttributeGet(MlirContext context);
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void);
+
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index 044381fcd4728..889365e1136b4 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -16,13 +16,13 @@
#include "PythonTestDialect.h.inc"
-#define GET_OP_CLASSES
-#include "PythonTestOps.h.inc"
-
#define GET_ATTRDEF_CLASSES
#include "PythonTestAttributes.h.inc"
#define GET_TYPEDEF_CLASSES
#include "PythonTestTypes.h.inc"
+#define GET_OP_CLASSES
+#include "PythonTestOps.h.inc"
+
#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f81b851f8759b..a4f538dcb5594 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -44,10 +44,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
py::arg("registry"));
mlir_attribute_subclass(m, "TestAttr",
- mlirAttributeIsAPythonTestTestAttribute)
+ mlirAttributeIsAPythonTestTestAttribute,
+ mlirPythonTestTestAttributeGetTypeID)
.def_classmethod(
"get",
- [](py::object cls, MlirContext ctx) {
+ [](const py::object &cls, MlirContext ctx) {
return cls(mlirPythonTestTestAttributeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
@@ -56,7 +57,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
mlirPythonTestTestTypeGetTypeID)
.def_classmethod(
"get",
- [](py::object cls, MlirContext ctx) {
+ [](const py::object &cls, MlirContext ctx) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 95301985e3fde..5a82c00ae6080 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -58,6 +58,10 @@ def AttributedOp : TestOp<"attributed_op"> {
UnitAttr:$unit);
}
+def CustomAttributedOp : TestOp<"custom_attributed_op"> {
+ let arguments = (ins TestAttr:$test_attr);
+}
+
def AttributesOp : TestOp<"attributes_op"> {
let arguments = (ins
AffineMapArrayAttr:$x_affinemaparr,
More information about the Mlir-commits mailing list