[Mlir-commits] [mlir] [mlir][python] value casting (PR #69644)

Maksim Levental llvmlistbot at llvm.org
Mon Oct 30 11:34:43 PDT 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/69644

From a91112fa8df552f12f3ee582112c85d14acc400b Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 25 Oct 2023 17:14:08 -0500
Subject: [PATCH 1/5] [mlir][python] fix typecaster replace

---
 mlir/lib/Bindings/Python/IRModule.cpp    |  8 ++++++++
 mlir/test/python/dialects/python_test.py | 11 +++++++++++
 mlir/test/python/ir/operation.py         | 24 ++++++++++++++++++++++++
 3 files changed, 43 insertions(+)

diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index a1c8ab7a09ce155..f8e22f7bb0c1ba7 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -82,6 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
   if (found && !found.is_none() && !replace)
     throw std::runtime_error("Type caster is already registered");
   found = std::move(typeCaster);
+  const auto foundIt = typeCasterMapCache.find(mlirTypeID);
+  if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
+    typeCasterMapCache[mlirTypeID] = found;
+  }
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -104,6 +108,10 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
                                  .str());
   }
   found = std::move(pyClass);
+  auto foundIt = operationClassMapCache.find(operationName);
+  if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
+    operationClassMapCache[operationName] = found;
+  }
 }
 
 std::optional<py::function>
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 651e6554eebe8bd..a70b6fd5e5e4d84 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -510,6 +510,17 @@ def type_caster(pytype):
         except RuntimeError as e:
             print(e)
 
+        def type_caster(pytype):
+            return RankedTensorType(pytype)
+
+        register_type_caster(c.typeid, type_caster, replace=True)
+
+        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
+        # CHECK: tensor<10x10xi5>
+        print(d.type)
+        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
+        print("ranked tensor type", repr(d.type))
+
         def type_caster(pytype):
             return test.TestIntegerRankedTensorType(pytype)
 
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 129b7fa744e4721..5ded4814e54bf66 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -5,6 +5,8 @@
 import itertools
 from mlir.ir import *
 from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.dialects._ods_common import _cext
 
 
 def run(f):
@@ -646,6 +648,7 @@ def testKnownOpView():
       %1 = "custom.f32"() : () -> f32
       %2 = "custom.f32"() : () -> f32
       %3 = arith.addf %1, %2 : f32
+      %4 = arith.constant 0 : i32
     """
         )
         print(module)
@@ -668,6 +671,27 @@ def testKnownOpView():
         # CHECK: OpView object
         print(repr(custom))
 
+        # constant should map to an extension OpView class in the arithmetic dialect.
+        constant = module.body.operations[3]
+        # CHECK: <mlir.dialects.arith.ConstantOp object
+        print(repr(constant))
+        # CHECK: literal value 0
+        print("literal value", constant.literal_value)
+
+        @_cext.register_operation(arith._Dialect, replace=True)
+        class ConstantOp(arith.ConstantOp):
+            def __init__(self, result, value, *, loc=None, ip=None):
+                if isinstance(value, int):
+                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+                elif isinstance(value, float):
+                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+                else:
+                    super().__init__(value, loc=loc, ip=ip)
+
+        constant = module.body.operations[3]
+        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
+        print(repr(constant))
+
 
 # CHECK-LABEL: TEST: testSingleResultProperty
 @run

From eecb8a056636b78d5dc03bc172134f1e98bf534f Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 11 Oct 2023 00:28:04 -0500
Subject: [PATCH 2/5] [mlir][python] value casting

---
 mlir/python/mlir/dialects/_ods_common.py      | 58 +++++++++++++++-
 mlir/python/mlir/ir.py                        | 14 ++++
 mlir/test/mlir-tblgen/op-python-bindings.td   | 48 ++++++-------
 mlir/test/python/dialects/arith_dialect.py    | 68 +++++++++++++++++--
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 17 +++--
 5 files changed, 171 insertions(+), 34 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 9cca7d659ec8cb3..dd41ee63c8bf7af 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,11 +1,18 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from collections import defaultdict
 
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+    Callable as _Callable,
+    Sequence as _Sequence,
+    Type as _Type,
+    TypeVar as _TypeVar,
+    Union as _Union,
+)
 
 __all__ = [
     "equally_sized_accessor",
@@ -123,3 +130,52 @@ def get_op_result_or_op_results(
         if len(op.results) > 0
         else op
     )
+
+
+U = _TypeVar("U", bound=_cext.ir.Value)
+SubClassValueT = _Type[U]
+
+ValueCasterT = _Callable[
+    [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
+]
+
+_VALUE_CASTERS: defaultdict[
+    _cext.ir.TypeID,
+    _Sequence[ValueCasterT],
+] = defaultdict(list)
+
+
+def has_value_caster(typeid: _cext.ir.TypeID):
+    if not isinstance(typeid, _cext.ir.TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if typeid in _VALUE_CASTERS:
+        return True
+    return False
+
+
+def get_value_caster(typeid: _cext.ir.TypeID):
+    if not has_value_caster(typeid):
+        raise ValueError(f"no registered caster for {typeid=}")
+    return _VALUE_CASTERS[typeid]
+
+
+def maybe_cast(
+    val: _Union[
+        _cext.ir.Value,
+        _cext.ir.OpResult,
+        _Sequence[_cext.ir.Value],
+        _Sequence[_cext.ir.OpResult],
+        _cext.ir.Operation,
+    ]
+) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
+    if isinstance(val, (tuple, list)):
+        return tuple(map(maybe_cast, val))
+
+    if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
+        return val
+
+    if has_value_caster(val.type.typeid):
+        for caster in get_value_caster(val.type.typeid):
+            if casted := caster(val):
+                return casted
+    return val
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 43553f3118a51fc..6e1f2b357f31711 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,20 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster
+from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS
+
+
+def register_value_caster(typeid: TypeID, priority: int = None):
+    def wrapper(caster: ValueCasterT):
+        if not isinstance(typeid, TypeID):
+            raise ValueError(f"{typeid=} is not a TypeID")
+        if priority is None:
+            _VALUE_CASTERS[typeid].append(caster)
+        else:
+            _VALUE_CASTERS[typeid].insert(priority, caster)
+        return caster
+
+    return wrapper
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 63dad1cc901fe2b..96b0c170dc5bb40 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 }
 
 // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 }
 
 // CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> {
 }
 
 // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK:     return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 }
 
 // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 }
 
 // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 }
 
 // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
 }
 
 // CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
@@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
 // CHECK: def empty(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 }
 
 // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 }
 
 // CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
 }
 
 // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 }
 
 // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 }
 
 // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
 }
 
 // CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
 }
 
 // CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
 }
 
 // CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 }
 
 // CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 }
 
 // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 }
 
 // CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> {
 }
 
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 }
 
 // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
 
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 }
 
 // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 }
 
 // CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 6d1c5eab7589847..180d30ff4cfb3e5 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,9 @@
 # RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
 
 from mlir.ir import *
-import mlir.dialects.func as func
 import mlir.dialects.arith as arith
+from mlir.dialects._ods_common import maybe_cast
 
 
 def run(f):
@@ -35,14 +36,71 @@ def testFastMathFlags():
             print(r)
 
 
-# CHECK-LABEL: TEST: testArithValueBuilder
+# CHECK-LABEL: TEST: testArithValue
 @run
-def testArithValueBuilder():
+def testArithValue():
+    def _binary_op(lhs, rhs, op: str):
+        op = op.capitalize()
+        if arith._is_float_type(lhs.type):
+            op += "F"
+        elif arith._is_integer_like_type(lhs.type):
+            op += "I"
+        else:
+            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+        op = getattr(arith, f"{op}Op")
+        return maybe_cast(op(lhs, rhs).result)
+
+    @register_value_caster(F16Type.static_typeid)
+    @register_value_caster(F32Type.static_typeid)
+    @register_value_caster(F64Type.static_typeid)
+    @register_value_caster(IntegerType.static_typeid)
+    class ArithValue(Value):
+        __add__ = partialmethod(_binary_op, op="add")
+        __sub__ = partialmethod(_binary_op, op="sub")
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    class ArithValue1(Value):
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue1")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    def no_op_caster(val):
+        print("no_op_caster", val)
+        return None
+
     with Context() as ctx, Location.unknown():
         module = Module.create()
+        f16_t = F16Type.get()
         f32_t = F32Type.get()
+        f64_t = F64Type.get()
+        i32 = IntegerType.get_signless(32)
 
         with InsertionPoint(module.body):
+            a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+            b = a + a
+            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+            print(b)
+
             a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
-            # CHECK: %cst = arith.constant 4.242000e+01 : f32
-            print(a)
+            b = a - a
+            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+            print(b)
+
+            a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+            b = a * a
+            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+            print(b)
+
+            # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32)
+            a = arith.constant(value=IntegerAttr.get(i32, 1))
+            b = a * a
+            # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            print(b)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c8ef84721090ab9..170ac6b87c693d7 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,16 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import (
+    SubClassValueT as _SubClassValueT,
+    equally_sized_accessor as _ods_equally_sized_accessor,
+    get_default_loc_context as _ods_get_default_loc_context,
+    get_op_result_or_op_results as _get_op_result_or_op_results,
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+    maybe_cast as _maybe_cast,
+    segmented_accessor as _ods_segmented_accessor,
+)
 _ods_ir = _ods_cext.ir
 
 import builtins
@@ -263,7 +272,7 @@ constexpr const char *regionAccessorTemplate = R"Py(
 
 constexpr const char *valueBuilderTemplate = R"Py(
 def {0}({2}) -> {4}:
-  return _get_op_result_or_op_results({1}({3}))
+  return _maybe_cast(_get_op_result_or_op_results({1}({3})))
 )Py";
 
 static llvm::cl::OptionCategory
@@ -1004,8 +1013,8 @@ static void emitValueBuilder(const Operator &op,
                       llvm::join(valueBuilderParams, ", "),
                       llvm::join(opBuilderArgs, ", "),
                       (op.getNumResults() > 1
-                           ? "_Sequence[_ods_ir.OpResult]"
-                           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+                           ? "_Sequence[_SubClassValueT]"
+                           : (op.getNumResults() > 0 ? "_SubClassValueT"
                                                      : "_ods_ir.Operation")));
 }
 

From e9594f5503a11790f05c06db371838d9bab25a54 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 23 Oct 2023 10:07:12 -0500
Subject: [PATCH 3/5] add new line to op-python-bindings.td

---
 mlir/test/mlir-tblgen/op-python-bindings.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 96b0c170dc5bb40..9844040f8a33c4b 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
\ No newline at end of file
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))

From a20c21b2e75b7afde0bc8fb0c204e8cf0f9286b9 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 25 Oct 2023 14:23:52 -0500
Subject: [PATCH 4/5] WIP opresult and opoperand and blockarg casting

---
 .../mlir/Bindings/Python/PybindAdaptors.h     |  1 +
 mlir/lib/Bindings/Python/Globals.h            | 16 ++++
 mlir/lib/Bindings/Python/IRCore.cpp           | 52 +++++++++-
 mlir/lib/Bindings/Python/IRModule.cpp         | 47 +++++++++
 mlir/lib/Bindings/Python/IRModule.h           | 10 +-
 mlir/lib/Bindings/Python/MainModule.cpp       | 12 +++
 mlir/lib/Bindings/Python/PybindUtils.h        |  2 +-
 mlir/python/mlir/dialects/_ods_common.py      | 52 +---------
 mlir/python/mlir/ir.py                        | 16 +---
 mlir/test/mlir-tblgen/op-python-bindings.td   | 48 +++++-----
 mlir/test/python/dialects/arith_dialect.py    | 28 ++----
 mlir/test/python/ir/value.py                  | 96 +++++++++++++++++++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |  3 +-
 13 files changed, 261 insertions(+), 122 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 49680c8b79b135e..acc90e4ab9a22b8 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -234,6 +234,7 @@ struct type_caster<MlirValue> {
     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
         .attr("Value")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
         .release();
   };
 };
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 21899bdce22e810..98d3b16836e0d33 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -70,6 +70,13 @@ class PyGlobals {
   void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
                           bool replace = false);
 
+  /// Adds a user-friendly value caster. Raises an exception if the mapping
+  /// already exists and replace == false. This is intended to be called by
+  /// implementation code.
+  void registerValueCaster(MlirTypeID mlirTypeID,
+                           pybind11::function valueCaster,
+                           bool replace = false);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -90,6 +97,10 @@ class PyGlobals {
   std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
                                                      MlirDialect dialect);
 
+  /// Returns the custom value caster for MlirTypeID mlirTypeID.
+  std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
+                                                      MlirDialect dialect);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   std::optional<pybind11::object>
@@ -116,6 +127,11 @@ class PyGlobals {
   /// Cache for map of MlirTypeID to custom type caster.
   llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
 
+  /// Map of MlirTypeID to custom value caster.
+  llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
+  /// Cache for map of MlirTypeID to custom value caster.
+  llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMapCache;
+
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
   llvm::StringSet<> loadedDialectModulesCache;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7cfea31dbb2e80c..2c7ffda4e088032 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1899,13 +1899,26 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
 }
 
 //------------------------------------------------------------------------------
-// PyValue and subclases.
+// PyValue and subclasses.
 //------------------------------------------------------------------------------
 
 pybind11::object PyValue::getCapsule() {
   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
 }
 
+pybind11::object PyValue::maybeDownCast() {
+  MlirType type = mlirValueGetType(get());
+  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
+  assert(!mlirTypeIDIsNull(mlirTypeID) &&
+         "mlirTypeID was expected to be non-null.");
+  std::optional<pybind11::function> valueCaster =
+      PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+  py::object this_ = py::cast(this, py::return_value_policy::move);
+  if (!valueCaster)
+    return this_;
+  return valueCaster.value()(this_);
+}
+
 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
   if (mlirValueIsNull(value))
@@ -2121,6 +2134,8 @@ class PyConcreteValue : public PyValue {
           return DerivedTy::isaFunction(otherValue);
         },
         py::arg("other_value"));
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+            [](DerivedTy &self) { return self.maybeDownCast(); });
     DerivedTy::bindDerived(cls);
   }
 
@@ -2193,6 +2208,7 @@ class PyBlockArgumentList
     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
 public:
   static constexpr const char *pyClassName = "BlockArgumentList";
+  using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
 
   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
                       intptr_t startIndex = 0, intptr_t length = -1,
@@ -2202,6 +2218,13 @@ class PyBlockArgumentList
                   step),
         operation(std::move(operation)), block(block) {}
 
+  pybind11::object getItem(intptr_t index) override {
+    auto item = this->SliceableT::getItem(index);
+    if (item.ptr() != nullptr)
+      return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
+    return item;
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
       return getValueTypes(self, self.operation->getContext());
@@ -2241,6 +2264,7 @@ class PyBlockArgumentList
 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
 public:
   static constexpr const char *pyClassName = "OpOperandList";
+  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
 
   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
                   intptr_t length = -1, intptr_t step = 1)
@@ -2250,6 +2274,13 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
                   step),
         operation(operation) {}
 
+  pybind11::object getItem(intptr_t index) override {
+    auto item = this->SliceableT::getItem(index);
+    if (item.ptr() != nullptr)
+      return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
+    return item;
+  }
+
   void dunderSetItem(intptr_t index, PyValue value) {
     index = wrapIndex(index);
     mlirOperationSetOperand(operation->get(), index, value.get());
@@ -2296,6 +2327,7 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
 public:
   static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
 
   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
                  intptr_t length = -1, intptr_t step = 1)
@@ -2303,7 +2335,14 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
                   length == -1 ? mlirOperationGetNumResults(operation->get())
                                : length,
                   step),
-        operation(operation) {}
+        operation(std::move(operation)) {}
+
+  pybind11::object getItem(intptr_t index) override {
+    auto item = this->SliceableT::getItem(index);
+    if (item.ptr() != nullptr)
+      return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
+    return item;
+  }
 
   static void bindDerived(ClassTy &c) {
     c.def_property_readonly("types", [](PyOpResultList &self) {
@@ -2891,8 +2930,9 @@ void mlir::python::populateIRCore(py::module &m) {
                    "single result)")
                       .str());
             }
-            return PyOpResult(operation.getRef(),
-                              mlirOperationGetResult(operation, 0));
+            PyOpResult result = PyOpResult(
+                operation.getRef(), mlirOperationGetResult(operation, 0));
+            return result.maybeDownCast();
           },
           "Shortcut to get an op result if it has only one (throws an error "
           "otherwise).")
@@ -3566,7 +3606,9 @@ void mlir::python::populateIRCore(py::module &m) {
           [](PyValue &self, PyValue &with) {
             mlirValueReplaceAllUsesOfWith(self.get(), with.get());
           },
-          kValueReplaceAllUsesWithDocstring);
+          kValueReplaceAllUsesWithDocstring)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyValue &self) { return self.maybeDownCast(); });
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
   PyOpOperand::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index f8e22f7bb0c1ba7..0131e81d2c5d6ce 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -88,6 +88,19 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
   }
 }
 
+void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
+                                    pybind11::function valueCaster,
+                                    bool replace) {
+  pybind11::object &found = valueCasterMap[mlirTypeID];
+  if (found && !found.is_none() && !replace)
+    throw std::runtime_error("Value caster is already registered");
+  found = std::move(valueCaster);
+  const auto foundIt = valueCasterMapCache.find(mlirTypeID);
+  if (foundIt != valueCasterMapCache.end() && !foundIt->second.is_none()) {
+    valueCasterMapCache[mlirTypeID] = found;
+  }
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
@@ -163,6 +176,39 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
   }
 }
 
+std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
+                                                         MlirDialect dialect) {
+  {
+    // Fast match against the value caster map first (common case).
+    const auto foundIt = valueCasterMapCache.find(mlirTypeID);
+    if (foundIt != valueCasterMapCache.end()) {
+      if (foundIt->second.is_none())
+        return std::nullopt;
+      assert(foundIt->second && "py::function is defined");
+      return foundIt->second;
+    }
+  }
+
+  // Not found. Load the dialect namespace.
+  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+
+  // Attempt to find from the canonical map and cache.
+  {
+    const auto foundIt = valueCasterMap.find(mlirTypeID);
+    if (foundIt != valueCasterMap.end()) {
+      if (foundIt->second.is_none())
+        return std::nullopt;
+      assert(foundIt->second && "py::object is defined");
+      // Positive cache.
+      valueCasterMapCache[mlirTypeID] = foundIt->second;
+      return foundIt->second;
+    }
+    // Negative cache.
+    valueCasterMap[mlirTypeID] = py::none();
+    return std::nullopt;
+  }
+}
+
 std::optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   loadDialectModule(dialectNamespace);
@@ -218,4 +264,5 @@ void PyGlobals::clearImportCache() {
   loadedDialectModulesCache.clear();
   operationClassMapCache.clear();
   typeCasterMapCache.clear();
+  valueCasterMapCache.clear();
 }
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 01ee4975d0e9a91..b95c4578fbc2220 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -761,7 +761,7 @@ class PyRegion {
 
 /// Wrapper around an MlirAsmState.
 class PyAsmState {
- public:
+public:
   PyAsmState(MlirValue value, bool useLocalScope) {
     flags = mlirOpPrintingFlagsCreate();
     // The OpPrintingFlags are not exposed Python side, create locally and
@@ -780,16 +780,14 @@ class PyAsmState {
     state =
         mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
   }
-  ~PyAsmState() {
-    mlirOpPrintingFlagsDestroy(flags);
-  }
+  ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
   // Delete copy constructors.
   PyAsmState(PyAsmState &other) = delete;
   PyAsmState(const PyAsmState &other) = delete;
 
   MlirAsmState get() { return state; }
 
- private:
+private:
   MlirAsmState state;
   MlirOpPrintingFlags flags;
 };
@@ -1124,6 +1122,8 @@ class PyValue {
   /// Gets a capsule wrapping the void* within the MlirValue.
   pybind11::object getCapsule();
 
+  virtual pybind11::object maybeDownCast();
+
   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
   /// the underlying MlirValue is still tied to the owning operation.
   static PyValue createFromCapsule(pybind11::object capsule);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index a936becf67bea75..aaa671fd82b6dbe 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -88,6 +88,18 @@ PYBIND11_MODULE(_mlir, m) {
       },
       "typeid"_a, "type_caster"_a, "replace"_a = false,
       "Register a type caster for casting MLIR types to custom user types.");
+  m.def(
+      "register_value_caster",
+      [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
+        return py::cpp_function(
+            [mlirTypeID, replace](py::object valueCaster) -> py::object {
+              PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
+                                                   replace);
+              return valueCaster;
+            });
+      },
+      "typeid"_a, "replace"_a = false,
+      "Register a value caster for casting MLIR values to custom user values.");
 
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 2a8da20bee0495d..efb7b713f80a40c 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -231,7 +231,7 @@ class Sliceable {
   /// Returns the element at the given slice index. Supports negative indices
   /// by taking elements in inverse order. Returns a nullptr object if out
   /// of bounds.
-  pybind11::object getItem(intptr_t index) {
+  virtual pybind11::object getItem(intptr_t index) {
     // Negative indices mean we count from the end.
     index = wrapIndex(index);
     if (index < 0) {
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index dd41ee63c8bf7af..fa73c197c17faf6 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -7,7 +7,6 @@
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
 from typing import (
-    Callable as _Callable,
     Sequence as _Sequence,
     Type as _Type,
     TypeVar as _TypeVar,
@@ -132,50 +131,7 @@ def get_op_result_or_op_results(
     )
 
 
-U = _TypeVar("U", bound=_cext.ir.Value)
-SubClassValueT = _Type[U]
-
-ValueCasterT = _Callable[
-    [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
-]
-
-_VALUE_CASTERS: defaultdict[
-    _cext.ir.TypeID,
-    _Sequence[ValueCasterT],
-] = defaultdict(list)
-
-
-def has_value_caster(typeid: _cext.ir.TypeID):
-    if not isinstance(typeid, _cext.ir.TypeID):
-        raise ValueError(f"{typeid=} is not a TypeID")
-    if typeid in _VALUE_CASTERS:
-        return True
-    return False
-
-
-def get_value_caster(typeid: _cext.ir.TypeID):
-    if not has_value_caster(typeid):
-        raise ValueError(f"no registered caster for {typeid=}")
-    return _VALUE_CASTERS[typeid]
-
-
-def maybe_cast(
-    val: _Union[
-        _cext.ir.Value,
-        _cext.ir.OpResult,
-        _Sequence[_cext.ir.Value],
-        _Sequence[_cext.ir.OpResult],
-        _cext.ir.Operation,
-    ]
-) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
-    if isinstance(val, (tuple, list)):
-        return tuple(map(maybe_cast, val))
-
-    if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
-        return val
-
-    if has_value_caster(val.type.typeid):
-        for caster in get_value_caster(val.type.typeid):
-            if casted := caster(val):
-                return casted
-    return val
+# This is the standard way to indicate subclass/inheritance relationship
+# see the typing.Type doc string.
+_U = _TypeVar("_U", bound=_cext.ir.Value)
+SubClassValueT = _Type[_U]
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6e1f2b357f31711..eede64de674e22b 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,21 +4,7 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
-from ._mlir_libs._mlir import register_type_caster
-from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS
-
-
-def register_value_caster(typeid: TypeID, priority: int = None):
-    def wrapper(caster: ValueCasterT):
-        if not isinstance(typeid, TypeID):
-            raise ValueError(f"{typeid=} is not a TypeID")
-        if priority is None:
-            _VALUE_CASTERS[typeid].append(caster)
-        else:
-            _VALUE_CASTERS[typeid].insert(priority, caster)
-        return caster
-
-    return wrapper
+from ._mlir_libs._mlir import register_type_caster, register_value_caster
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9844040f8a33c4b..f7df8ba2df0ae2f 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 }
 
 // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 }
 
 // CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> {
 }
 
 // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK:     return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)))
+// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 }
 
 // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 }
 
 // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 }
 
 // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
 }
 
 // CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
@@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
 // CHECK: def empty(*, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 }
 
 // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 }
 
 // CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
 }
 
 // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 }
 
 // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 }
 
 // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
 }
 
 // CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
 }
 
 // CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
 }
 
 // CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 }
 
 // CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 }
 
 // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 }
 
 // CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> {
 }
 
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 }
 
 // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
 
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 }
 
 // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 }
 
 // CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
+// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 180d30ff4cfb3e5..39c3d5799a6563a 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -3,7 +3,7 @@
 
 from mlir.ir import *
 import mlir.dialects.arith as arith
-from mlir.dialects._ods_common import maybe_cast
+import mlir.dialects.func as func
 
 
 def run(f):
@@ -49,31 +49,22 @@ def _binary_op(lhs, rhs, op: str):
             raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
 
         op = getattr(arith, f"{op}Op")
-        return maybe_cast(op(lhs, rhs).result)
+        return op(lhs, rhs).result
 
     @register_value_caster(F16Type.static_typeid)
     @register_value_caster(F32Type.static_typeid)
     @register_value_caster(F64Type.static_typeid)
     @register_value_caster(IntegerType.static_typeid)
     class ArithValue(Value):
+        def __init__(self, v):
+            super().__init__(v)
+
         __add__ = partialmethod(_binary_op, op="add")
         __sub__ = partialmethod(_binary_op, op="sub")
         __mul__ = partialmethod(_binary_op, op="mul")
 
         def __str__(self):
-            return super().__str__().replace("Value", "ArithValue")
-
-    @register_value_caster(IntegerType.static_typeid, priority=0)
-    class ArithValue1(Value):
-        __mul__ = partialmethod(_binary_op, op="mul")
-
-        def __str__(self):
-            return super().__str__().replace("Value", "ArithValue1")
-
-    @register_value_caster(IntegerType.static_typeid, priority=0)
-    def no_op_caster(val):
-        print("no_op_caster", val)
-        return None
+            return super().__str__().replace(Value.__name__, ArithValue.__name__)
 
     with Context() as ctx, Location.unknown():
         module = Module.create()
@@ -97,10 +88,3 @@ def no_op_caster(val):
             b = a * a
             # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
             print(b)
-
-            # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32)
-            a = arith.constant(value=IntegerAttr.get(i32, 1))
-            b = a * a
-            # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32)
-            # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32)
-            print(b)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index ddf653dcce27804..1c3e1a6ae9654fe 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -270,3 +270,99 @@ def testValueSetType():
 
             # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
             print(value.owner)
+
+
+# CHECK-LABEL: TEST: testValueCasters
+@run
+def testValueCasters():
+    class NOPResult(OpResult):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPResult.__name__)
+
+    class NOPValue(Value):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPValue.__name__)
+
+    class NOPBlockArg(BlockArgument):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
+
+    @register_value_caster(IntegerType.static_typeid)
+    def cast_int(v):
+        print("in caster", v.__class__.__name__)
+        if isinstance(v, OpResult):
+            return NOPResult(v)
+        if isinstance(v, BlockArgument):
+            return NOPBlockArg(v)
+        elif isinstance(v, Value):
+            return NOPValue(v)
+
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            values = Operation.create("custom.op1", results=[i32, i32]).results
+            # CHECK: in caster OpResult
+            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", values[0].result_number, values[0])
+            # CHECK: in caster OpResult
+            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", values[1].result_number, values[0])
+
+            value0, value1 = values
+            # CHECK: in caster OpResult
+            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", value0.result_number, values[0])
+            # CHECK: in caster OpResult
+            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", value1.result_number, values[0])
+
+            op1 = Operation.create("custom.op2", operands=[value0, value1])
+            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
+            print(op1)
+
+            # CHECK: in caster Value
+            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("operand 0", op1.operands[0])
+            # CHECK: in caster Value
+            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("operand 1", op1.operands[1])
+
+            # CHECK: in caster BlockArgument
+            # CHECK: in caster BlockArgument
+            @func.FuncOp.from_py_func(i32, i32)
+            def reduction(arg0, arg1):
+                # CHECK: as func arg 0 NOPBlockArg
+                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
+                # CHECK: as func arg 1 NOPBlockArg
+                print("as func arg", arg1.arg_number, arg1.__class__.__name__)
+
+    @register_value_caster(IntegerType.static_typeid, replace=True)
+    def dont_cast_int(v):
+        print("don't cast", v.result_number, v)
+        return v
+
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
+            new_value = Operation.create("custom.op1", results=[i32]).result
+            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
+            print("result", new_value.result_number, new_value)
+
+            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
+            new_value = Operation.create("custom.op2", results=[i32]).results[0]
+            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
+            print("result", new_value.result_number, new_value)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 170ac6b87c693d7..0c0ad2cfeffdcc2 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -37,7 +37,6 @@ from ._ods_common import (
     get_op_result_or_op_results as _get_op_result_or_op_results,
     get_op_result_or_value as _get_op_result_or_value,
     get_op_results_or_values as _get_op_results_or_values,
-    maybe_cast as _maybe_cast,
     segmented_accessor as _ods_segmented_accessor,
 )
 _ods_ir = _ods_cext.ir
@@ -272,7 +271,7 @@ constexpr const char *regionAccessorTemplate = R"Py(
 
 constexpr const char *valueBuilderTemplate = R"Py(
 def {0}({2}) -> {4}:
-  return _maybe_cast(_get_op_result_or_op_results({1}({3})))
+  return _get_op_result_or_op_results({1}({3}))
 )Py";
 
 static llvm::cl::OptionCategory

>From 0e8da418fe0eb0e38356a43fc0ee867d3cb65ed1 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 30 Oct 2023 13:34:21 -0500
Subject: [PATCH 5/5] done with opresult, blockarg casting

---
 mlir/include/mlir-c/Bindings/Python/Interop.h | 18 ++++++++-
 mlir/lib/Bindings/Python/IRCore.cpp           | 12 +++---
 mlir/lib/Bindings/Python/MainModule.cpp       |  2 +-
 mlir/python/mlir/dialects/_ods_common.py      |  1 -
 mlir/test/python/dialects/arith_dialect.py    |  3 +-
 mlir/test/python/dialects/python_test.py      |  6 +++
 mlir/test/python/ir/value.py                  |  5 ++-
 mlir/test/python/lib/PythonTestModule.cpp     | 40 +++++++++++++++----
 8 files changed, 66 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index f79c10cb9383829..9b026a6b922de47 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -121,10 +121,26 @@
  *   def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
  *                              bool replace)
  * where replace indicates the typeCaster should replace any existing registered
- * type casters (such as those for upstream ConcreteTypes).
+ * type casters (such as those for upstream ConcreteTypes). The interface of the
+ * typeCaster is:
+ *   def type_caster(ir.Type) -> SubClassTypeT
+ * where SubClassTypeT indicates the result should be a subclass (inherit from)
+ * ir.Type.
  */
 #define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
 
+/** Attribute on main C extension module (_mlir) that corresponds to the
+ * value caster registration binding. The signature of the function is:
+ *   def register_value_caster(MlirTypeID mlirTypeID, bool replace)
+ * which returns a decorator that takes the py::function valueCaster; replace
+ * indicates the valueCaster should replace any existing registered value
+ * casters. The interface of the valueCaster is:
+ *   def value_caster(ir.Value) -> SubClassValueT
+ * where SubClassValueT indicates the result should be a subclass (inherit from)
+ * ir.Value.
+ */
+#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"
+
 /// Gets a void* from a wrapped struct. Needed because const cast is different
 /// between C/C++.
 #ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2c7ffda4e088032..53eb75f810c1845 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1913,10 +1913,10 @@ pybind11::object PyValue::maybeDownCast() {
          "mlirTypeID was expected to be non-null.");
   std::optional<pybind11::function> valueCaster =
       PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
-  py::object this_ = py::cast(this, py::return_value_policy::move);
+  py::object thisObj = py::cast(this, py::return_value_policy::move);
   if (!valueCaster)
-    return this_;
-  return valueCaster.value()(this_);
+    return thisObj;
+  return valueCaster.value()(thisObj);
 }
 
 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
@@ -2930,9 +2930,9 @@ void mlir::python::populateIRCore(py::module &m) {
                    "single result)")
                       .str());
             }
-            PyOpResult result = PyOpResult(
-                operation.getRef(), mlirOperationGetResult(operation, 0));
-            return result.maybeDownCast();
+            return PyOpResult(operation.getRef(),
+                              mlirOperationGetResult(operation, 0))
+                .maybeDownCast();
           },
           "Shortcut to get an op result if it has only one (throws an error "
           "otherwise).")
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index aaa671fd82b6dbe..10287e48d6814f2 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -89,7 +89,7 @@ PYBIND11_MODULE(_mlir, m) {
       "typeid"_a, "type_caster"_a, "replace"_a = false,
       "Register a type caster for casting MLIR types to custom user types.");
   m.def(
-      "register_value_caster",
+      MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
       [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
         return py::cpp_function(
             [mlirTypeID, replace](py::object valueCaster) -> py::object {
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index fa73c197c17faf6..60ce83c09f1717e 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,7 +1,6 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from collections import defaultdict
 
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 39c3d5799a6563a..c8d21dfb62ed557 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -39,7 +39,7 @@ def testFastMathFlags():
 # CHECK-LABEL: TEST: testArithValue
 @run
 def testArithValue():
-    def _binary_op(lhs, rhs, op: str):
+    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
         op = op.capitalize()
         if arith._is_float_type(lhs.type):
             op += "F"
@@ -71,7 +71,6 @@ def __str__(self):
         f16_t = F16Type.get()
         f32_t = F32Type.get()
         f64_t = F64Type.get()
-        i32 = IntegerType.get_signless(32)
 
         with InsertionPoint(module.body):
             a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index a70b6fd5e5e4d84..c5e39c0354c191d 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -426,6 +426,12 @@ def __str__(self):
             # And it should be equal to the in-tree concrete type
             assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
 
+            d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
+            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
+            print(d)
+            # CHECK: <importlib._bootstrap.TestTensorValue object at
+            print(repr(d))
+
 
 # CHECK-LABEL: TEST: inferReturnTypeComponents
 @run
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 1c3e1a6ae9654fe..873c20b329bc46b 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -3,6 +3,7 @@
 import gc
 from mlir.ir import *
 from mlir.dialects import func
+from mlir.dialects._ods_common import SubClassValueT
 
 
 def run(f):
@@ -297,7 +298,7 @@ def __str__(self):
             return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
 
     @register_value_caster(IntegerType.static_typeid)
-    def cast_int(v):
+    def cast_int(v) -> SubClassValueT:
         print("in caster", v.__class__.__name__)
         if isinstance(v, OpResult):
             return NOPResult(v)
@@ -349,7 +350,7 @@ def reduction(arg0, arg1):
                 print("as func arg", arg1.arg_number, arg1.__class__.__name__)
 
     @register_value_caster(IntegerType.static_typeid, replace=True)
-    def dont_cast_int(v):
+    def dont_cast_int(v) -> Value:
         print("don't cast", v.result_number, v)
         return v
 
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f533082a0a147c0..1e584343d0f0a85 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -42,6 +42,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
             return cls(mlirPythonTestTestAttributeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
+
   mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
                      mlirPythonTestTestTypeGetTypeID)
       .def_classmethod(
@@ -50,7 +51,8 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
-  auto cls =
+
+  auto typeCls =
       mlir_type_subclass(m, "TestIntegerRankedTensorType",
                          mlirTypeIsARankedIntegerTensor,
                          py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
@@ -65,16 +67,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
                     encoding));
               },
               "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
-  assert(py::hasattr(cls.get_class(), "static_typeid") &&
+
+  assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
          "TestIntegerRankedTensorType has no static_typeid");
-  MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+
+  MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
   py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
       .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
-          mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
-            return cls.get_class()(mlirType);
+          mlirRankedTensorTypeID,
+          pybind11::cpp_function([typeCls](const py::object &mlirType) {
+            return typeCls.get_class()(mlirType);
           }),
           /*replace=*/true);
-  mlir_value_subclass(m, "TestTensorValue",
-                      mlirTypeIsAPythonTestTestTensorValue)
-      .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
+
+  auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+                                      mlirTypeIsAPythonTestTestTensorValue)
+                      .def("is_null", [](MlirValue &self) {
+                        return mlirValueIsNull(self);
+                      });
+
+  py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+      .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+          mlirRankedTensorTypeID)(
+          pybind11::cpp_function([valueCls](const py::object &valueObj) {
+            py::object capsule = mlirApiObjectToCapsule(valueObj);
+            MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+            MlirType t = mlirValueGetType(v);
+            if (mlirShapedTypeHasStaticShape(t) &&
+                mlirShapedTypeGetDimSize(t, 0) == 1 &&
+                mlirShapedTypeGetDimSize(t, 1) == 2 &&
+                mlirShapedTypeGetDimSize(t, 2) == 3)
+              return valueCls.get_class()(valueObj);
+            return valueObj;
+          }));
 }



More information about the Mlir-commits mailing list