[Mlir-commits] [mlir] 7c85086 - [mlir][python] value casting (#69644)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 7 08:49:46 PST 2023
Author: Maksim Levental
Date: 2023-11-07T10:49:41-06:00
New Revision: 7c850867b9ef4427375da6d83c34d0b9c944fcb8
URL: https://github.com/llvm/llvm-project/commit/7c850867b9ef4427375da6d83c34d0b9c944fcb8
DIFF: https://github.com/llvm/llvm-project/commit/7c850867b9ef4427375da6d83c34d0b9c944fcb8.diff
LOG: [mlir][python] value casting (#69644)
This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a
proxy class that overloads dunders such as `__add__`, `__sub__`, and
`__mul__` for fun and great profit.
This is thematically similar to
https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
and
https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f.
The example in the test demonstrates the value of the feature (no pun
intended):
```python
@register_value_caster(F16Type.static_typeid)
@register_value_caster(F32Type.static_typeid)
@register_value_caster(F64Type.static_typeid)
@register_value_caster(IntegerType.static_typeid)
class ArithValue(Value):
    __add__ = partialmethod(_binary_op, op="add")
    __sub__ = partialmethod(_binary_op, op="sub")
    __mul__ = partialmethod(_binary_op, op="mul")
a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
b = a + a
# CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
print(b)
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
b = a - a
# CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
print(b)
a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
```
**EDIT**: this now goes through the bindings and thus supports automatic
casting of `OpResult` (including as an element of `OpResultList`),
`BlockArgument` (including as an element of `BlockArgumentList`), as
well as `Value`.
Added:
Modified:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/include/mlir/Bindings/Python/PybindAdaptors.h
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/Bindings/Python/PybindUtils.h
mlir/python/mlir/dialects/_ods_common.py
mlir/python/mlir/ir.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/arith_dialect.py
mlir/test/python/dialects/python_test.py
mlir/test/python/ir/value.py
mlir/test/python/lib/PythonTestModule.cpp
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index f79c10cb9383829..0a36e97c2ae6831 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -118,13 +118,28 @@
/** Attribute on main C extension module (_mlir) that corresponds to the
* type caster registration binding. The signature of the function is:
- * def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
- * bool replace)
- * where replace indicates the typeCaster should replace any existing registered
- * type casters (such as those for upstream ConcreteTypes).
+ * def register_type_caster(MlirTypeID mlirTypeID, *, bool replace)
+ * which then takes a typeCaster (register_type_caster is meant to be used as a
+ * decorator from python), and where replace indicates the typeCaster should
+ * replace any existing registered type casters (such as those for upstream
+ * ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type)
+ * -> SubClassTypeT where SubClassTypeT indicates the result should be a
+ * subclass (inherit from) ir.Type.
*/
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
+/** Attribute on main C extension module (_mlir) that corresponds to the
+ * value caster registration binding. The signature of the function is:
+ * def register_value_caster(MlirTypeID mlirTypeID, *, bool replace)
+ * which then takes a valueCaster (register_value_caster is meant to be used as
+ * a decorator, from python), and where replace indicates the valueCaster should
+ * replace any existing registered value casters. The interface of the
+ * valueCaster is: def value_caster(ir.Value) -> SubClassValueT where
+ * SubClassValueT indicates the result should be a subclass (inherit from)
+ * ir.Value.
+ */
+#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"
+
/// Gets a void* from a wrapped struct. Needed because const cast is different
/// between C/C++.
#ifdef __cplusplus
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 49680c8b79b135e..5e0e56fc00a6736 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -234,6 +234,7 @@ struct type_caster<MlirValue> {
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Value")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
};
};
@@ -496,11 +497,10 @@ class mlir_type_subclass : public pure_subclass {
if (getTypeIDFunction) {
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- getTypeIDFunction(),
- pybind11::cpp_function(
- [thisClass = thisClass](const py::object &mlirType) {
- return thisClass(mlirType);
- }));
+ getTypeIDFunction())(pybind11::cpp_function(
+ [thisClass = thisClass](const py::object &mlirType) {
+ return thisClass(mlirType);
+ }));
}
}
};
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 976297257ced06e..a022067f5c7e575 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -66,6 +66,13 @@ class PyGlobals {
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
bool replace = false);
+ /// Adds a user-friendly value caster. Raises an exception if the mapping
+ /// already exists and replace == false. This is intended to be called by
+ /// implementation code.
+ void registerValueCaster(MlirTypeID mlirTypeID,
+ pybind11::function valueCaster,
+ bool replace = false);
+
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
@@ -86,6 +93,10 @@ class PyGlobals {
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
+ /// Returns the custom value caster for MlirTypeID mlirTypeID.
+ std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
+ MlirDialect dialect);
+
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
@@ -109,7 +120,8 @@ class PyGlobals {
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
-
+ /// Map of MlirTypeID to custom value caster.
+ llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7cfea31dbb2e80c..0f2ca666ccc050e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1899,13 +1899,28 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
}
//------------------------------------------------------------------------------
-// PyValue and subclases.
+// PyValue and subclasses.
//------------------------------------------------------------------------------
pybind11::object PyValue::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
}
+pybind11::object PyValue::maybeDownCast() {
+ MlirType type = mlirValueGetType(get());
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ std::optional<pybind11::function> valueCaster =
+ PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+ // py::return_value_policy::move means use std::move to move the return value
+ // contents into a new instance that will be owned by Python.
+ py::object thisObj = py::cast(this, py::return_value_policy::move);
+ if (!valueCaster)
+ return thisObj;
+ return valueCaster.value()(thisObj);
+}
+
PyValue PyValue::createFromCapsule(pybind11::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
@@ -2121,6 +2136,8 @@ class PyConcreteValue : public PyValue {
return DerivedTy::isaFunction(otherValue);
},
py::arg("other_value"));
+ cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](DerivedTy &self) { return self.maybeDownCast(); });
DerivedTy::bindDerived(cls);
}
@@ -2193,6 +2210,7 @@ class PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
+ using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
@@ -2241,6 +2259,7 @@ class PyBlockArgumentList
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
+ using SliceableT = Sliceable<PyOpOperandList, PyValue>;
PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
@@ -2296,6 +2315,7 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
+ using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
@@ -2303,7 +2323,7 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
length == -1 ? mlirOperationGetNumResults(operation->get())
: length,
step),
- operation(operation) {}
+ operation(std::move(operation)) {}
static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyOpResultList &self) {
@@ -2892,7 +2912,8 @@ void mlir::python::populateIRCore(py::module &m) {
.str());
}
return PyOpResult(operation.getRef(),
- mlirOperationGetResult(operation, 0));
+ mlirOperationGetResult(operation, 0))
+ .maybeDownCast();
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
@@ -3566,7 +3587,9 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyValue &self, PyValue &with) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
- kValueReplaceAllUsesWithDocstring);
+ kValueReplaceAllUsesWithDocstring)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) { return self.maybeDownCast(); });
PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 6c5cde86236ce90..5538924d2481849 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -88,6 +88,16 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
found = std::move(typeCaster);
}
+void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
+ pybind11::function valueCaster,
+ bool replace) {
+ pybind11::object &found = valueCasterMap[mlirTypeID];
+ if (found && !replace)
+ throw std::runtime_error("Value caster is already registered: " +
+ py::repr(found).cast<std::string>());
+ found = std::move(valueCaster);
+}
+
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
@@ -134,6 +144,17 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}
+std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
+ MlirDialect dialect) {
+ loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+ const auto foundIt = valueCasterMap.find(mlirTypeID);
+ if (foundIt != valueCasterMap.end()) {
+ assert(foundIt->second && "value caster is defined");
+ return foundIt->second;
+ }
+ return std::nullopt;
+}
+
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 01ee4975d0e9a91..af55693f18fbbf9 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -761,7 +761,7 @@ class PyRegion {
/// Wrapper around an MlirAsmState.
class PyAsmState {
- public:
+public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
// The OpPrintingFlags are not exposed Python side, create locally and
@@ -780,16 +780,14 @@ class PyAsmState {
state =
mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
}
- ~PyAsmState() {
- mlirOpPrintingFlagsDestroy(flags);
- }
+ ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
PyAsmState(const PyAsmState &other) = delete;
MlirAsmState get() { return state; }
- private:
+private:
MlirAsmState state;
MlirOpPrintingFlags flags;
};
@@ -1112,6 +1110,10 @@ class PyConcreteAttribute : public BaseTy {
/// bindings so such operation always exists).
class PyValue {
public:
+ // The virtual here is "load bearing" in that it enables RTTI
+ // for PyConcreteValue CRTP classes that support maybeDownCast.
+ // See PyValue::maybeDownCast.
+ virtual ~PyValue() = default;
PyValue(PyOperationRef parentOperation, MlirValue value)
: parentOperation(std::move(parentOperation)), value(value) {}
operator MlirValue() const { return value; }
@@ -1124,6 +1126,8 @@ class PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
pybind11::object getCapsule();
+ pybind11::object maybeDownCast();
+
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
static PyValue createFromCapsule(pybind11::object capsule);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 2ba3a3677198cbc..17272472ccca42a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,8 +12,6 @@
#include "IRModule.h"
#include "Pass.h"
-#include <tuple>
-
namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
@@ -46,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a, "replace"_a = false,
+ "operation_name"_a, "operation_class"_a, py::kw_only(),
+ "replace"_a = false,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@@ -82,17 +81,32 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
- "dialect_class"_a, "replace"_a = false,
+ "dialect_class"_a, py::kw_only(), "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
- PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
- replace);
+ [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
+ return py::cpp_function([mlirTypeID,
+ replace](py::object typeCaster) -> py::object {
+ PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
+ return typeCaster;
+ });
},
- "typeid"_a, "type_caster"_a, "replace"_a = false,
+ "typeid"_a, py::kw_only(), "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
+ m.def(
+ MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
+ [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
+ return py::cpp_function(
+ [mlirTypeID, replace](py::object valueCaster) -> py::object {
+ PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
+ replace);
+ return valueCaster;
+ });
+ },
+ "typeid"_a, py::kw_only(), "replace"_a = false,
+ "Register a value caster for casting MLIR values to custom user values.");
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 2a8da20bee0495d..38462ac8ba6db9c 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -10,6 +10,7 @@
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#include "mlir-c/Support.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
@@ -228,6 +229,11 @@ class Sliceable {
return linearIndex;
}
+ /// Trait to check if T provides a `maybeDownCast` method.
+ /// Note, you need the & to detect inherited members.
+ template <typename T, typename... Args>
+ using has_maybe_downcast = decltype(&T::maybeDownCast);
+
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
@@ -239,8 +245,13 @@ class Sliceable {
return {};
}
- return pybind11::cast(
- static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
+ if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
+ return static_cast<Derived *>(this)
+ ->getRawElement(linearizeIndex(index))
+ .maybeDownCast();
+ else
+ return pybind11::cast(
+ static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}
/// Returns a new instance of the pseudo-container restricted to the given
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 9cca7d659ec8cb3..60ce83c09f1717e 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -5,7 +5,12 @@
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+ Sequence as _Sequence,
+ Type as _Type,
+ TypeVar as _TypeVar,
+ Union as _Union,
+)
__all__ = [
"equally_sized_accessor",
@@ -123,3 +128,9 @@ def get_op_result_or_op_results(
if len(op.results) > 0
else op
)
+
+
+# This is the standard way to indicate subclass/inheritance relationship
+# see the typing.Type doc string.
+_U = _TypeVar("_U", bound=_cext.ir.Value)
+SubClassValueT = _Type[_U]
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index cf4228c2a63a91b..18526ab8c3c02dc 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,7 +4,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
-from ._mlir_libs._mlir import register_type_caster
+from ._mlir_libs._mlir import register_type_caster, register_value_caster
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 63dad1cc901fe2b..f7df8ba2df0ae2f 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 6d1c5eab7589847..f80f2c084a0f3b8 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,9 @@
# RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
from mlir.ir import *
-import mlir.dialects.func as func
import mlir.dialects.arith as arith
+import mlir.dialects.func as func
def run(f):
@@ -35,14 +36,59 @@ def testFastMathFlags():
print(r)
-# CHECK-LABEL: TEST: testArithValueBuilder
+# CHECK-LABEL: TEST: testArithValue
@run
-def testArithValueBuilder():
+def testArithValue():
+ def _binary_op(lhs, rhs, op: str) -> "ArithValue":
+ op = op.capitalize()
+ if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+ op += "F"
+ elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
+ lhs.type
+ ):
+ op += "I"
+ else:
+ raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+ op = getattr(arith, f"{op}Op")
+ return op(lhs, rhs).result
+
+ @register_value_caster(F16Type.static_typeid)
+ @register_value_caster(F32Type.static_typeid)
+ @register_value_caster(F64Type.static_typeid)
+ @register_value_caster(IntegerType.static_typeid)
+ class ArithValue(Value):
+ def __init__(self, v):
+ super().__init__(v)
+
+ __add__ = partialmethod(_binary_op, op="add")
+ __sub__ = partialmethod(_binary_op, op="sub")
+ __mul__ = partialmethod(_binary_op, op="mul")
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, ArithValue.__name__)
+
with Context() as ctx, Location.unknown():
module = Module.create()
+ f16_t = F16Type.get()
f32_t = F32Type.get()
+ f64_t = F64Type.get()
with InsertionPoint(module.body):
- a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
- # CHECK: %cst = arith.constant 4.242000e+01 : f32
+ a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+ # CHECK: ArithValue(%cst = arith.constant 4.240
print(a)
+
+ b = a + a
+ # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+ print(b)
+
+ a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+ b = a - a
+ # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+ print(b)
+
+ a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+ b = a * a
+ # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+ print(b)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 472db7e5124dbed..f313a400b73c0a5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -425,6 +425,12 @@ def __str__(self):
# And it should be equal to the in-tree concrete type
assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
+ d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
+ # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
+ print(d)
+ # CHECK: TestTensorValue
+ print(repr(d))
+
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
@@ -502,19 +508,18 @@ def testCustomTypeTypeCaster():
# CHECK: Type caster is already registered
try:
+ @register_type_caster(c.typeid)
def type_caster(pytype):
return test.TestIntegerRankedTensorType(pytype)
- register_type_caster(c.typeid, type_caster)
except RuntimeError as e:
print(e)
- def type_caster(pytype):
- return RankedTensorType(pytype)
-
# python_test dialect registers a caster for RankedTensorType in its extension (pybind) module.
# So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
- register_type_caster(c.typeid, type_caster, replace=True)
+ @register_type_caster(c.typeid, replace=True)
+ def type_caster(pytype):
+ return RankedTensorType(pytype)
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
# CHECK: tensor<10x10xi5>
@@ -522,11 +527,10 @@ def type_caster(pytype):
# CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
print("ranked tensor type", repr(d.type))
+ @register_type_caster(c.typeid, replace=True)
def type_caster(pytype):
return test.TestIntegerRankedTensorType(pytype)
- register_type_caster(c.typeid, type_caster, replace=True)
-
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
# CHECK: tensor<10x10xi5>
print(d.type)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index ddf653dcce27804..acbf463113a6d59 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
from mlir.dialects import func
+from mlir.dialects._ods_common import SubClassValueT
def run(f):
@@ -270,3 +271,120 @@ def testValueSetType():
# CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
print(value.owner)
+
+
+# CHECK-LABEL: TEST: testValueCasters
+@run
+def testValueCasters():
+ class NOPResult(OpResult):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPResult.__name__)
+
+ class NOPValue(Value):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPValue.__name__)
+
+ class NOPBlockArg(BlockArgument):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
+
+ @register_value_caster(IntegerType.static_typeid)
+ def cast_int(v) -> SubClassValueT:
+ print("in caster", v.__class__.__name__)
+ if isinstance(v, OpResult):
+ return NOPResult(v)
+ if isinstance(v, BlockArgument):
+ return NOPBlockArg(v)
+ elif isinstance(v, Value):
+ return NOPValue(v)
+
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ values = Operation.create("custom.op1", results=[i32, i32]).results
+ # CHECK: in caster OpResult
+ # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", values[0].result_number, values[0])
+ # CHECK: in caster OpResult
+ # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", values[1].result_number, values[1])
+
+ # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("results slice", values[:1][0].result_number, values[:1][0])
+
+ value0, value1 = values
+ # CHECK: in caster OpResult
+ # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", value0.result_number, values[0])
+ # CHECK: in caster OpResult
+ # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", value1.result_number, values[1])
+
+ op1 = Operation.create("custom.op2", operands=[value0, value1])
+ # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
+ print(op1)
+
+ # CHECK: in caster Value
+ # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("operand 0", op1.operands[0])
+ # CHECK: in caster Value
+ # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("operand 1", op1.operands[1])
+
+ # CHECK: in caster BlockArgument
+ # CHECK: in caster BlockArgument
+ @func.FuncOp.from_py_func(i32, i32)
+ def reduction(arg0, arg1):
+ # CHECK: as func arg 0 NOPBlockArg
+ print("as func arg", arg0.arg_number, arg0.__class__.__name__)
+ # CHECK: as func arg 1 NOPBlockArg
+ print("as func arg", arg1.arg_number, arg1.__class__.__name__)
+
+ # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
+ print(
+ "args slice",
+ reduction.func_op.arguments[:1][0].arg_number,
+ reduction.func_op.arguments[:1][0],
+ )
+
+ try:
+
+ @register_value_caster(IntegerType.static_typeid)
+ def dont_cast_int_shouldnt_register(v):
+ ...
+
+ except RuntimeError as e:
+ # CHECK: Value caster is already registered: {{.*}}cast_int
+ print(e)
+
+ @register_value_caster(IntegerType.static_typeid, replace=True)
+ def dont_cast_int(v) -> OpResult:
+ assert isinstance(v, OpResult)
+ print("don't cast", v.result_number, v)
+ return v
+
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
+ new_value = Operation.create("custom.op1", results=[i32]).result
+ # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
+ print("result", new_value.result_number, new_value)
+
+ # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
+ new_value = Operation.create("custom.op2", results=[i32]).results[0]
+ # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
+ print("result", new_value.result_number, new_value)
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f533082a0a147c0..aff414894cb825a 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -42,6 +42,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestAttributeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
+
mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
mlirPythonTestTestTypeGetTypeID)
.def_classmethod(
@@ -50,7 +51,8 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
- auto cls =
+
+ auto typeCls =
mlir_type_subclass(m, "TestIntegerRankedTensorType",
mlirTypeIsARankedIntegerTensor,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
@@ -65,16 +67,40 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
encoding));
},
"cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
- assert(py::hasattr(cls.get_class(), "static_typeid") &&
+
+ assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
"TestIntegerRankedTensorType has no static_typeid");
- MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+
+ MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID,
+ "replace"_a = true)(
+ pybind11::cpp_function([typeCls](const py::object &mlirType) {
+ return typeCls.get_class()(mlirType);
+ }));
+
+ auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+ mlirTypeIsAPythonTestTestTensorValue)
+ .def("is_null", [](MlirValue &self) {
+ return mlirValueIsNull(self);
+ });
+
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
- return cls.get_class()(mlirType);
- }),
- /*replace=*/true);
- mlir_value_subclass(m, "TestTensorValue",
- mlirTypeIsAPythonTestTestTensorValue)
- .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
+ .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+ mlirRankedTensorTypeID)(
+ pybind11::cpp_function([valueCls](const py::object &valueObj) {
+ py::object capsule = mlirApiObjectToCapsule(valueObj);
+ MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+ MlirType t = mlirValueGetType(v);
+ // This is hyper-specific in order to exercise/test registering a
+ // value caster from cpp (but only for a single test case; see
+ // testTensorValue python_test.py).
+ if (mlirShapedTypeHasStaticShape(t) &&
+ mlirShapedTypeGetDimSize(t, 0) == 1 &&
+ mlirShapedTypeGetDimSize(t, 1) == 2 &&
+ mlirShapedTypeGetDimSize(t, 2) == 3)
+ return valueCls.get_class()(valueObj);
+ return valueObj;
+ }));
}
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c8ef84721090ab9..0c0ad2cfeffdcc2 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,15 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import (
+ SubClassValueT as _SubClassValueT,
+ equally_sized_accessor as _ods_equally_sized_accessor,
+ get_default_loc_context as _ods_get_default_loc_context,
+ get_op_result_or_op_results as _get_op_result_or_op_results,
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ segmented_accessor as _ods_segmented_accessor,
+)
_ods_ir = _ods_cext.ir
import builtins
@@ -1004,8 +1012,8 @@ static void emitValueBuilder(const Operator &op,
llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
(op.getNumResults() > 1
- ? "_Sequence[_ods_ir.OpResult]"
- : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+ ? "_Sequence[_SubClassValueT]"
+ : (op.getNumResults() > 0 ? "_SubClassValueT"
: "_ods_ir.Operation")));
}
More information about the Mlir-commits
mailing list