[Mlir-commits] [mlir] [mlir][python] value casting (PR #69644)

llvmlistbot at llvm.org
Thu Oct 19 15:29:42 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-core

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a proxy class that overloads dunders such as `__add__`, `__sub__`, and `__mul__` for fun and great profit. 
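
As a rough illustration of what such a proxy buys you, here is a minimal sketch that does not depend on MLIR. `FakeValue`, `ArithProxy`, and the `EMITTED` list are hypothetical stand-ins and not part of the MLIR Python bindings; the dispatch on a type string only loosely mirrors the float/integer dispatch in this PR's test. The point is just that overloading dunders lets `a + a` build an op instead of requiring an explicit builder call.

```python
from functools import partialmethod

EMITTED = []  # textual "ops" built so far (stand-in for a real module body)


class FakeValue:
    """Hypothetical stand-in for `ir.Value`: an SSA-like name plus a type string."""

    def __init__(self, name, type_):
        self.name = name
        self.type = type_


def _binary_op(lhs, rhs, op):
    # Pick the float or integer flavour of the op from the operand type.
    suffix = "f" if lhs.type.startswith("f") else "i"
    result = f"%{len(EMITTED)}"
    EMITTED.append(f"{result} = arith.{op}{suffix} {lhs.name}, {rhs.name} : {lhs.type}")
    # Return the same proxy class so that results can be chained.
    return type(lhs)(result, lhs.type)


class ArithProxy(FakeValue):
    # Each dunder is `_binary_op` with the op name pre-bound.
    __add__ = partialmethod(_binary_op, op="add")
    __sub__ = partialmethod(_binary_op, op="sub")
    __mul__ = partialmethod(_binary_op, op="mul")


a = ArithProxy("%cst", "f32")
b = a + a
c = b * a
print("\n".join(EMITTED))
```

Because every result is again an `ArithProxy`, expressions such as `(a + a) * a` compose without any further wrapping at the call site.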

This is thematically similar to https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb and https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f (and indeed relies on the former), but is implemented at the Python API level because:

1. it's quite tough to intervene in all the places `ir.Value` is returned from the pybind code, in particular [`PyOpResultList::slice`](https://github.com/llvm/llvm-project/blob/5a600c23f9e01f58bac09a8fad096e194fc90ae2/mlir/lib/Bindings/Python/IRCore.cpp#L2303);
2. registering Python classes to be conjured up by the pybind code is fairly brittle IMHO;
3. it's always better to have things be purely in Python so they can more easily be understood/patched/monkey-patched.

This has been my secret (lol) plan for quite a while now and it wasn't possible until https://github.com/llvm/llvm-project/pull/68308 but now we've arrived.

The design looks like this:

- a `ValueCasterT` is a callable that takes an `ir.Value` (or `ir.OpResult`) and returns either an instance of a subclass of `ir.Value` or `None`;
- the user registers a value caster for a `TypeID` (using `register_value_caster`); that caster is responsible for casting `ir.Value`s with a matching `type.typeid`;
- an `ir.Value` produced by a value builder is passed through `maybe_cast`, which casts it if a value caster is registered for its `type.typeid`.

In the absence of a registered value caster, the `ir.Value` or `ir.OpResult` is just returned unscathed.
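
The registration and lookup flow described above can be sketched without MLIR as follows. This is a simplified model, not the code in the patch: `FakeValue` and `F32Value` are hypothetical, plain strings stand in for `TypeID`s, and `priority` is left out.

```python
from collections import defaultdict


class FakeValue:
    """Hypothetical stand-in for `ir.Value`; `typeid` plays the role of `type.typeid`."""

    def __init__(self, name, typeid):
        self.name = name
        self.typeid = typeid


# Maps a type id to the list of casters registered for it.
_VALUE_CASTERS = defaultdict(list)


def register_value_caster(typeid):
    def wrapper(caster):
        _VALUE_CASTERS[typeid].append(caster)
        return caster

    return wrapper


def maybe_cast(val):
    # Sequences of values are cast element by element.
    if isinstance(val, (tuple, list)):
        return tuple(map(maybe_cast, val))
    # Anything that is not a value (e.g. an operation) passes through.
    if not isinstance(val, FakeValue):
        return val
    # `.get` avoids creating empty entries in the defaultdict on lookup.
    for caster in _VALUE_CASTERS.get(val.typeid, ()):
        casted = caster(val)
        if casted is not None:
            return casted
    # No caster registered, or every caster declined: return the value as-is.
    return val


@register_value_caster("f32")
class F32Value(FakeValue):
    """A class can serve as its own caster: calling it wraps the value."""

    def __init__(self, val):
        super().__init__(val.name, val.typeid)


x = maybe_cast(FakeValue("%0", "f32"))
y = maybe_cast(FakeValue("%1", "i32"))
print(type(x).__name__, type(y).__name__)  # F32Value FakeValue
```

Registering the proxy class itself as the caster, as the `ArithValue` test in this PR does, keeps the registration to a single decorator line.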

The user can optionally provide a `priority` when registering the value caster in order to allow for more/less specific proxies. The `priority` is used as the insertion index into the list of casters for that `TypeID`, so casters at lower indices are tried first. For example, one might want to attempt to wrap a `tensor<10x10x!tt.ptr<f32>>` (a tensor of [triton pointers](https://github.com/makslevental/triton-pp/blob/main/triton_pp/dialects/ext/triton.py#L380)), which has the same `TypeID` as `RankedTensorType`, which might already have been [wrapped (possibly upstream)](https://github.com/makslevental/mlir-python-utils/blob/main/mlir/utils/dialects/ext/tensor.py#L111). Should the higher-priority caster fail, it can return `None` and the next caster in priority order will be tried.

**Note, this doesn't affect any paths other than through the "value builders"**.

cc @stephenneuendorffer

---

Patch is 21.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69644.diff


5 Files Affected:

- (modified) mlir/python/mlir/dialects/_ods_common.py (+57-1) 
- (modified) mlir/python/mlir/ir.py (+14) 
- (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+24-24) 
- (modified) mlir/test/python/dialects/arith_dialect.py (+63-5) 
- (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+13-4) 


``````````diff
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 9cca7d659ec8cb3..dd41ee63c8bf7af 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,11 +1,18 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from collections import defaultdict
 
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+    Callable as _Callable,
+    Sequence as _Sequence,
+    Type as _Type,
+    TypeVar as _TypeVar,
+    Union as _Union,
+)
 
 __all__ = [
     "equally_sized_accessor",
@@ -123,3 +130,52 @@ def get_op_result_or_op_results(
         if len(op.results) > 0
         else op
     )
+
+
+U = _TypeVar("U", bound=_cext.ir.Value)
+SubClassValueT = _Type[U]
+
+ValueCasterT = _Callable[
+    [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
+]
+
+_VALUE_CASTERS: defaultdict[
+    _cext.ir.TypeID,
+    _Sequence[ValueCasterT],
+] = defaultdict(list)
+
+
+def has_value_caster(typeid: _cext.ir.TypeID):
+    if not isinstance(typeid, _cext.ir.TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if typeid in _VALUE_CASTERS:
+        return True
+    return False
+
+
+def get_value_caster(typeid: _cext.ir.TypeID):
+    if not has_value_caster(typeid):
+        raise ValueError(f"no registered caster for {typeid=}")
+    return _VALUE_CASTERS[typeid]
+
+
+def maybe_cast(
+    val: _Union[
+        _cext.ir.Value,
+        _cext.ir.OpResult,
+        _Sequence[_cext.ir.Value],
+        _Sequence[_cext.ir.OpResult],
+        _cext.ir.Operation,
+    ]
+) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
+    if isinstance(val, (tuple, list)):
+        return tuple(map(maybe_cast, val))
+
+    if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
+        return val
+
+    if has_value_caster(val.type.typeid):
+        for caster in get_value_caster(val.type.typeid):
+            if casted := caster(val):
+                return casted
+    return val
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 43553f3118a51fc..6e1f2b357f31711 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,20 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster
+from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS
+
+
+def register_value_caster(typeid: TypeID, priority: int = None):
+    def wrapper(caster: ValueCasterT):
+        if not isinstance(typeid, TypeID):
+            raise ValueError(f"{typeid=} is not a TypeID")
+        if priority is None:
+            _VALUE_CASTERS[typeid].append(caster)
+        else:
+            _VALUE_CASTERS[typeid].insert(priority, caster)
+        return caster
+
+    return wrapper
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 63dad1cc901fe2b..96b0c170dc5bb40 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 }
 
 // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 }
 
 // CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> {
 }
 
 // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK:     return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 }
 
 // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 }
 
 // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 }
 
 // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
 }
 
 // CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
@@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
 // CHECK: def empty(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 }
 
 // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 }
 
 // CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
 }
 
 // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 }
 
 // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 }
 
 // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
 }
 
 // CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
 }
 
 // CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
 }
 
 // CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 }
 
 // CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 }
 
 // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 }
 
 // CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> {
 }
 
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 }
 
 // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
 
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 }
 
 // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 }
 
 // CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 6d1c5eab7589847..180d30ff4cfb3e5 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,9 @@
 # RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
 
 from mlir.ir import *
-import mlir.dialects.func as func
 import mlir.dialects.arith as arith
+from mlir.dialects._ods_common import maybe_cast
 
 
 def run(f):
@@ -35,14 +36,71 @@ def testFastMathFlags():
             print(r)
 
 
-# CHECK-LABEL: TEST: testArithValueBuilder
+# CHECK-LABEL: TEST: testArithValue
 @run
-def testArithValueBuilder():
+def testArithValue():
+    def _binary_op(lhs, rhs, op: str):
+        op = op.capitalize()
+        if arith._is_float_type(lhs.type):
+            op += "F"
+        elif arith._is_integer_like_type(lhs.type):
+            op += "I"
+        else:
+            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+        op = getattr(arith, f"{op}Op")
+        return maybe_cast(op(lhs, rhs).result)
+
+    @register_value_caster(F16Type.static_typeid)
+    @register_value_caster(F32Type.static_typeid)
+    @register_value_caster(F64Type.static_typeid)
+    @register_value_caster(IntegerType.static_typeid)
+    class ArithValue(Value):
+        __add__ = partialmethod(_binary_op, op="add")
+        __sub__ = partialmethod(_binary_op, op="sub")
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    class ArithValue1(Value):
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue1")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    def no_op_caster(val):
+        print("no_op_caster", val)
+        return None
+
     with Context() as ctx, Location.unknown():
         module = Module.create()
+        f16_t = F16Type.get()
         f32_t = F32Type.get()
+        f64_t = F64Type.get()
+        i32 = IntegerType.get_signless(32)
 
         with InsertionPoint(module.body):
+            a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+            b = a + a
+            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+            print(b)
+
             a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
-            # CHECK: %cst = arith.constant 4.242000e+01 : f32
-            print(a)
+            b = a - a
+            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+            print(b)
+
+            a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+            b = a * a
+            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+            print(b)
+
+            # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32)
+            a = arith.constant(value=IntegerAttr.get(i32, 1))
+            b = a * a
+            # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            print(b)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c8ef84721090ab9..170ac6b87c693d7 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,16 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_resul...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/69644

