[all-commits] [llvm/llvm-project] 7c8508: [mlir][python] value casting (#69644)
Maksim Levental via All-commits
all-commits at lists.llvm.org
Tue Nov 7 08:49:55 PST 2023
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: 7c850867b9ef4427375da6d83c34d0b9c944fcb8
https://github.com/llvm/llvm-project/commit/7c850867b9ef4427375da6d83c34d0b9c944fcb8
Author: Maksim Levental <maksim.levental at gmail.com>
Date: 2023-11-07 (Tue, 07 Nov 2023)
Changed paths:
M mlir/include/mlir-c/Bindings/Python/Interop.h
M mlir/include/mlir/Bindings/Python/PybindAdaptors.h
M mlir/lib/Bindings/Python/Globals.h
M mlir/lib/Bindings/Python/IRCore.cpp
M mlir/lib/Bindings/Python/IRModule.cpp
M mlir/lib/Bindings/Python/IRModule.h
M mlir/lib/Bindings/Python/MainModule.cpp
M mlir/lib/Bindings/Python/PybindUtils.h
M mlir/python/mlir/dialects/_ods_common.py
M mlir/python/mlir/ir.py
M mlir/test/mlir-tblgen/op-python-bindings.td
M mlir/test/python/dialects/arith_dialect.py
M mlir/test/python/dialects/python_test.py
M mlir/test/python/ir/value.py
M mlir/test/python/lib/PythonTestModule.cpp
M mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Log Message:
-----------
[mlir][python] value casting (#69644)
This PR adds "value casting", i.e., a mechanism to wrap an `ir.Value` in
a user-defined proxy class, registered per MLIR type through its
`TypeID`, that overloads dunders such as `__add__`, `__sub__`, and
`__mul__` for fun and great profit.
This is thematically similar to
https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
and
https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f.
The example in the test demonstrates the value of the feature (no pun
intended):
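```python
from functools import partialmethod

from mlir.dialects import arith
from mlir.ir import (F16Type, F32Type, F64Type, FloatAttr, IntegerType,
                     Value, register_value_caster)

# Excerpt from the test: the `_binary_op` helper (which builds the
# matching arith op, e.g. `arith.addf` for f16 operands) and the
# Context/Location/InsertionPoint setup that defines f16_t, f32_t and
# f64_t are omitted here.


@register_value_caster(F16Type.static_typeid)
@register_value_caster(F32Type.static_typeid)
@register_value_caster(F64Type.static_typeid)
@register_value_caster(IntegerType.static_typeid)
class ArithValue(Value):
    __add__ = partialmethod(_binary_op, op="add")
    __sub__ = partialmethod(_binary_op, op="sub")
    __mul__ = partialmethod(_binary_op, op="mul")


# The constant's result is cast to ArithValue, so `+`, `-`, `*` build arith ops.
a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
b = a + a
# CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
print(b)
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
b = a - a
# CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
print(b)
a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
```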
**EDIT**: this now goes through the C++ bindings, so values are cast
automatically whether they arrive as an `OpResult` (including as an
element of `OpResultList`), a `BlockArgument` (including as an element
of `BlockArgumentList`), or a plain `Value`.
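As a rough pure-Python model of why list elements are covered too, the sketch below assumes casting is applied at the point where a value is handed back to Python, so indexing or iterating a result list yields already-cast values. `CastingList`, `IndexValue`, `_CASTERS` and `maybe_cast` are hypothetical names for illustration; they do not mirror the bindings' actual classes.

```python
from collections.abc import Sequence

# Hypothetical toy registry: type id -> caster callable.
_CASTERS = {}


class Value:
    """Toy stand-in for ir.Value."""

    def __init__(self, name, type_id):
        self.name = name
        self.type_id = type_id


class IndexValue(Value):
    """Hypothetical proxy class registered for the "index" type id."""

    def __init__(self, v):
        super().__init__(v.name, v.type_id)


_CASTERS["index"] = IndexValue


def maybe_cast(value):
    caster = _CASTERS.get(value.type_id)
    return caster(value) if caster is not None else value


class CastingList(Sequence):
    """Toy model of OpResultList / BlockArgumentList: elements are cast
    when they are handed out, not when the list is built."""

    def __init__(self, values):
        self._values = list(values)

    def __len__(self):
        return len(self._values)

    def __getitem__(self, index):
        if isinstance(index, slice):
            return CastingList(self._values[index])
        # Raises IndexError past the end, which also ends iteration.
        return maybe_cast(self._values[index])


results = CastingList([Value("%0", "index"), Value("%1", "f32")])
print([type(r).__name__ for r in results])  # ['IndexValue', 'Value']
```

Casting lazily on access keeps the underlying list untouched, so a caster registered after the list was created still applies to later lookups.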