[all-commits] [llvm/llvm-project] 7c8508: [mlir][python] value casting (#69644)

Maksim Levental via All-commits all-commits at lists.llvm.org
Tue Nov 7 08:49:55 PST 2023


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 7c850867b9ef4427375da6d83c34d0b9c944fcb8
      https://github.com/llvm/llvm-project/commit/7c850867b9ef4427375da6d83c34d0b9c944fcb8
  Author: Maksim Levental <maksim.levental at gmail.com>
  Date:   2023-11-07 (Tue, 07 Nov 2023)

  Changed paths:
    M mlir/include/mlir-c/Bindings/Python/Interop.h
    M mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    M mlir/lib/Bindings/Python/Globals.h
    M mlir/lib/Bindings/Python/IRCore.cpp
    M mlir/lib/Bindings/Python/IRModule.cpp
    M mlir/lib/Bindings/Python/IRModule.h
    M mlir/lib/Bindings/Python/MainModule.cpp
    M mlir/lib/Bindings/Python/PybindUtils.h
    M mlir/python/mlir/dialects/_ods_common.py
    M mlir/python/mlir/ir.py
    M mlir/test/mlir-tblgen/op-python-bindings.td
    M mlir/test/python/dialects/arith_dialect.py
    M mlir/test/python/dialects/python_test.py
    M mlir/test/python/ir/value.py
    M mlir/test/python/lib/PythonTestModule.cpp
    M mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

  Log Message:
  -----------
  [mlir][python] value casting (#69644)

This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a
proxy class that overloads dunders such as `__add__`, `__sub__`, and
`__mul__` for fun and great profit.

This is thematically similar to
https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
and
https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f.
The example in the test demonstrates the value of the feature (no pun
intended):

```python
# Excerpt from the test: `_binary_op`, the type objects (`f16_t`, `f32_t`,
# `f64_t`), and the MLIR imports come from the surrounding test code.
from functools import partialmethod

# Register one proxy class for several element types by stacking decorators.
@register_value_caster(F16Type.static_typeid)
@register_value_caster(F32Type.static_typeid)
@register_value_caster(F64Type.static_typeid)
@register_value_caster(IntegerType.static_typeid)
class ArithValue(Value):
    __add__ = partialmethod(_binary_op, op="add")
    __sub__ = partialmethod(_binary_op, op="sub")
    __mul__ = partialmethod(_binary_op, op="mul")

a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
b = a + a
# CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
print(b)

a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
b = a - a
# CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
print(b)

a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
```

**EDIT**: the casting now goes through the bindings themselves and thus
applies automatically to `OpResult` (including as an element of
`OpResultList`), `BlockArgument` (including as an element of
`BlockArgumentList`), and plain `Value`.
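
The registration and lookup idea can be sketched without MLIR at all. The
following is a toy model in plain Python, not the actual bindings: the
`_VALUE_CASTERS` dict, the `maybe_cast` helper, the string type ids, and
this stand-in `Value` class are all assumptions made for illustration. It
only shows how a registry keyed on a type id can hand back a proxy class
whose dunders build new operations.

```python
from functools import partialmethod

# Hypothetical registry mapping a type id to a proxy class (illustration only).
_VALUE_CASTERS = {}


def register_value_caster(type_id):
    """Toy decorator: record `cls` as the proxy class for `type_id`."""
    def decorator(cls):
        _VALUE_CASTERS[type_id] = cls
        return cls
    return decorator


class Value:
    """Stand-in for an IR value: a type id plus the text that defines it."""

    def __init__(self, type_id, expr):
        self.type_id = type_id
        self.expr = expr

    def __repr__(self):
        return f"{type(self).__name__}({self.expr} : {self.type_id})"


def maybe_cast(value):
    """Return the registered proxy for the value's type, else the value itself."""
    caster = _VALUE_CASTERS.get(value.type_id)
    if caster is None:
        return value
    return caster(value.type_id, value.expr)


def _binary_op(lhs, rhs, op):
    """Build a toy 'arith' op string and cast the result through the registry."""
    if lhs.type_id != rhs.type_id:
        raise TypeError(f"mismatched types: {lhs.type_id} vs {rhs.type_id}")
    suffix = "f" if lhs.type_id.startswith("f") else "i"
    result = Value(lhs.type_id, f"arith.{op}{suffix} {lhs.expr}, {rhs.expr}")
    return maybe_cast(result)


# Stacked decorators register the same proxy class for several type ids.
@register_value_caster("f16")
@register_value_caster("f32")
@register_value_caster("i32")
class ArithValue(Value):
    __add__ = partialmethod(_binary_op, op="add")
    __sub__ = partialmethod(_binary_op, op="sub")
    __mul__ = partialmethod(_binary_op, op="mul")


a = maybe_cast(Value("f16", "%cst"))   # registered type: becomes ArithValue
b = a + a                              # result is cast again via the registry
c = maybe_cast(Value("index", "%i"))   # unregistered type: stays a plain Value
print(b)
print(c)
```

Because every result is routed back through `maybe_cast`, chained
expressions such as `a + a - a` keep producing the proxy type. In the toy
model that routing is an explicit call in `_binary_op`; the **EDIT** note
above describes the bindings doing the equivalent step on their side.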



