[Mlir-commits] [mlir] 7c85086 - [mlir][python] value casting (#69644)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 7 08:49:46 PST 2023


Author: Maksim Levental
Date: 2023-11-07T10:49:41-06:00
New Revision: 7c850867b9ef4427375da6d83c34d0b9c944fcb8

URL: https://github.com/llvm/llvm-project/commit/7c850867b9ef4427375da6d83c34d0b9c944fcb8
DIFF: https://github.com/llvm/llvm-project/commit/7c850867b9ef4427375da6d83c34d0b9c944fcb8.diff

LOG: [mlir][python] value casting (#69644)

This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a
proxy class that overloads dunders such as `__add__`, `__sub__`, and
`__mul__` for fun and great profit.

This is thematically similar to
https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
and
https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f.
The example in the test demonstrates the value of the feature (no pun
intended):

```python
    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

    a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
    b = a + a
    # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
    print(b)

    a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
    b = a - a
    # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
    print(b)

    a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
    b = a * a
    # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
    print(b)
```

**EDIT**: this now goes through the bindings and thus supports automatic
casting of `OpResult` (including as an element of `OpResultList`),
`BlockArgument` (including as an element of `BlockArgumentList`), as
well as `Value`.

Added: 
    

Modified: 
    mlir/include/mlir-c/Bindings/Python/Interop.h
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/python/mlir/dialects/_ods_common.py
    mlir/python/mlir/ir.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/test/python/dialects/arith_dialect.py
    mlir/test/python/dialects/python_test.py
    mlir/test/python/ir/value.py
    mlir/test/python/lib/PythonTestModule.cpp
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index f79c10cb9383829..0a36e97c2ae6831 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -118,13 +118,28 @@
 
 /** Attribute on main C extension module (_mlir) that corresponds to the
  * type caster registration binding. The signature of the function is:
- *   def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
- *                              bool replace)
- * where replace indicates the typeCaster should replace any existing registered
- * type casters (such as those for upstream ConcreteTypes).
+ *   def register_type_caster(MlirTypeID mlirTypeID, *, bool replace)
+ * which then takes a typeCaster (register_type_caster is meant to be used as a
+ * decorator from python), and where replace indicates the typeCaster should
+ * replace any existing registered type casters (such as those for upstream
+ * ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type)
+ * -> SubClassTypeT where SubClassTypeT indicates the result should be a
+ * subclass (inherit from) ir.Type.
  */
 #define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
 
+/** Attribute on main C extension module (_mlir) that corresponds to the
+ * value caster registration binding. The signature of the function is:
+ *   def register_value_caster(MlirTypeID mlirTypeID, *, bool replace)
+ * which then takes a valueCaster (register_value_caster is meant to be used as
+ * a decorator, from python), and where replace indicates the valueCaster should
+ * replace any existing registered value casters. The interface of the
+ * valueCaster is: def value_caster(ir.Value) -> SubClassValueT where
+ * SubClassValueT indicates the result should be a subclass (inherit from)
+ * ir.Value.
+ */
+#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"
+
 /// Gets a void* from a wrapped struct. Needed because const cast is different
 /// between C/C++.
 #ifdef __cplusplus

diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 49680c8b79b135e..5e0e56fc00a6736 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -234,6 +234,7 @@ struct type_caster<MlirValue> {
     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
         .attr("Value")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
         .release();
   };
 };
@@ -496,11 +497,10 @@ class mlir_type_subclass : public pure_subclass {
     if (getTypeIDFunction) {
       py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
           .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
-              getTypeIDFunction(),
-              pybind11::cpp_function(
-                  [thisClass = thisClass](const py::object &mlirType) {
-                    return thisClass(mlirType);
-                  }));
+              getTypeIDFunction())(pybind11::cpp_function(
+              [thisClass = thisClass](const py::object &mlirType) {
+                return thisClass(mlirType);
+              }));
     }
   }
 };

diff  --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 976297257ced06e..a022067f5c7e575 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -66,6 +66,13 @@ class PyGlobals {
   void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
                           bool replace = false);
 
+  /// Adds a user-friendly value caster. Raises an exception if the mapping
+  /// already exists and replace == false. This is intended to be called by
+  /// implementation code.
+  void registerValueCaster(MlirTypeID mlirTypeID,
+                           pybind11::function valueCaster,
+                           bool replace = false);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -86,6 +93,10 @@ class PyGlobals {
   std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
                                                      MlirDialect dialect);
 
+  /// Returns the custom value caster for MlirTypeID mlirTypeID.
+  std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
+                                                      MlirDialect dialect);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   std::optional<pybind11::object>
@@ -109,7 +120,8 @@ class PyGlobals {
   llvm::StringMap<pybind11::object> attributeBuilderMap;
   /// Map of MlirTypeID to custom type caster.
   llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
-
+  /// Map of MlirTypeID to custom value caster.
+  llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
   llvm::StringSet<> loadedDialectModules;

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7cfea31dbb2e80c..0f2ca666ccc050e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1899,13 +1899,28 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
 }
 
 //------------------------------------------------------------------------------
-// PyValue and subclases.
+// PyValue and subclasses.
 //------------------------------------------------------------------------------
 
 pybind11::object PyValue::getCapsule() {
   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
 }
 
+pybind11::object PyValue::maybeDownCast() {
+  MlirType type = mlirValueGetType(get());
+  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
+  assert(!mlirTypeIDIsNull(mlirTypeID) &&
+         "mlirTypeID was expected to be non-null.");
+  std::optional<pybind11::function> valueCaster =
+      PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+  // py::return_value_policy::move means use std::move to move the return value
+  // contents into a new instance that will be owned by Python.
+  py::object thisObj = py::cast(this, py::return_value_policy::move);
+  if (!valueCaster)
+    return thisObj;
+  return valueCaster.value()(thisObj);
+}
+
 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
   if (mlirValueIsNull(value))
@@ -2121,6 +2136,8 @@ class PyConcreteValue : public PyValue {
           return DerivedTy::isaFunction(otherValue);
         },
         py::arg("other_value"));
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+            [](DerivedTy &self) { return self.maybeDownCast(); });
     DerivedTy::bindDerived(cls);
   }
 
@@ -2193,6 +2210,7 @@ class PyBlockArgumentList
     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
 public:
   static constexpr const char *pyClassName = "BlockArgumentList";
+  using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
 
   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
                       intptr_t startIndex = 0, intptr_t length = -1,
@@ -2241,6 +2259,7 @@ class PyBlockArgumentList
 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
 public:
   static constexpr const char *pyClassName = "OpOperandList";
+  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
 
   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
                   intptr_t length = -1, intptr_t step = 1)
@@ -2296,6 +2315,7 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
 public:
   static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
 
   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
                  intptr_t length = -1, intptr_t step = 1)
@@ -2303,7 +2323,7 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
                   length == -1 ? mlirOperationGetNumResults(operation->get())
                                : length,
                   step),
-        operation(operation) {}
+        operation(std::move(operation)) {}
 
   static void bindDerived(ClassTy &c) {
     c.def_property_readonly("types", [](PyOpResultList &self) {
@@ -2892,7 +2912,8 @@ void mlir::python::populateIRCore(py::module &m) {
                       .str());
             }
             return PyOpResult(operation.getRef(),
-                              mlirOperationGetResult(operation, 0));
+                              mlirOperationGetResult(operation, 0))
+                .maybeDownCast();
           },
           "Shortcut to get an op result if it has only one (throws an error "
           "otherwise).")
@@ -3566,7 +3587,9 @@ void mlir::python::populateIRCore(py::module &m) {
           [](PyValue &self, PyValue &with) {
             mlirValueReplaceAllUsesOfWith(self.get(), with.get());
           },
-          kValueReplaceAllUsesWithDocstring);
+          kValueReplaceAllUsesWithDocstring)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyValue &self) { return self.maybeDownCast(); });
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
   PyOpOperand::bind(m);

diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 6c5cde86236ce90..5538924d2481849 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -88,6 +88,16 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
   found = std::move(typeCaster);
 }
 
+void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
+                                    pybind11::function valueCaster,
+                                    bool replace) {
+  pybind11::object &found = valueCasterMap[mlirTypeID];
+  if (found && !replace)
+    throw std::runtime_error("Value caster is already registered: " +
+                             py::repr(found).cast<std::string>());
+  found = std::move(valueCaster);
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
@@ -134,6 +144,17 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
   return std::nullopt;
 }
 
+std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
+                                                         MlirDialect dialect) {
+  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+  const auto foundIt = valueCasterMap.find(mlirTypeID);
+  if (foundIt != valueCasterMap.end()) {
+    assert(foundIt->second && "value caster is defined");
+    return foundIt->second;
+  }
+  return std::nullopt;
+}
+
 std::optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   // Make sure dialect module is loaded.

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 01ee4975d0e9a91..af55693f18fbbf9 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -761,7 +761,7 @@ class PyRegion {
 
 /// Wrapper around an MlirAsmState.
 class PyAsmState {
- public:
+public:
   PyAsmState(MlirValue value, bool useLocalScope) {
     flags = mlirOpPrintingFlagsCreate();
     // The OpPrintingFlags are not exposed Python side, create locally and
@@ -780,16 +780,14 @@ class PyAsmState {
     state =
         mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
   }
-  ~PyAsmState() {
-    mlirOpPrintingFlagsDestroy(flags);
-  }
+  ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
   // Delete copy constructors.
   PyAsmState(PyAsmState &other) = delete;
   PyAsmState(const PyAsmState &other) = delete;
 
   MlirAsmState get() { return state; }
 
- private:
+private:
   MlirAsmState state;
   MlirOpPrintingFlags flags;
 };
@@ -1112,6 +1110,10 @@ class PyConcreteAttribute : public BaseTy {
 /// bindings so such operation always exists).
 class PyValue {
 public:
+  // The virtual here is "load bearing" in that it enables RTTI
+  // for PyConcreteValue CRTP classes that support maybeDownCast.
+  // See PyValue::maybeDownCast.
+  virtual ~PyValue() = default;
   PyValue(PyOperationRef parentOperation, MlirValue value)
       : parentOperation(std::move(parentOperation)), value(value) {}
   operator MlirValue() const { return value; }
@@ -1124,6 +1126,8 @@ class PyValue {
   /// Gets a capsule wrapping the void* within the MlirValue.
   pybind11::object getCapsule();
 
+  pybind11::object maybeDownCast();
+
   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
   /// the underlying MlirValue is still tied to the owning operation.
   static PyValue createFromCapsule(pybind11::object capsule);

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 2ba3a3677198cbc..17272472ccca42a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,8 +12,6 @@
 #include "IRModule.h"
 #include "Pass.h"
 
-#include <tuple>
-
 namespace py = pybind11;
 using namespace mlir;
 using namespace py::literals;
@@ -46,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) {
            "dialect_namespace"_a, "dialect_class"_a,
            "Testing hook for directly registering a dialect")
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
-           "operation_name"_a, "operation_class"_a, "replace"_a = false,
+           "operation_name"_a, "operation_class"_a, py::kw_only(),
+           "replace"_a = false,
            "Testing hook for directly registering an operation");
 
   // Aside from making the globals accessible to python, having python manage
@@ -82,17 +81,32 @@ PYBIND11_MODULE(_mlir, m) {
               return opClass;
             });
       },
-      "dialect_class"_a, "replace"_a = false,
+      "dialect_class"_a, py::kw_only(), "replace"_a = false,
       "Produce a class decorator for registering an Operation class as part of "
       "a dialect");
   m.def(
       MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
-      [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
-        PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
-                                            replace);
+      [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
+        return py::cpp_function([mlirTypeID,
+                                 replace](py::object typeCaster) -> py::object {
+          PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
+          return typeCaster;
+        });
       },
-      "typeid"_a, "type_caster"_a, "replace"_a = false,
+      "typeid"_a, py::kw_only(), "replace"_a = false,
       "Register a type caster for casting MLIR types to custom user types.");
+  m.def(
+      MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
+      [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
+        return py::cpp_function(
+            [mlirTypeID, replace](py::object valueCaster) -> py::object {
+              PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
+                                                   replace);
+              return valueCaster;
+            });
+      },
+      "typeid"_a, py::kw_only(), "replace"_a = false,
+      "Register a value caster for casting MLIR values to custom user values.");
 
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 2a8da20bee0495d..38462ac8ba6db9c 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -10,6 +10,7 @@
 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
 
 #include "mlir-c/Support.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
 
@@ -228,6 +229,11 @@ class Sliceable {
     return linearIndex;
   }
 
+  /// Trait to check if T provides a `maybeDownCast` method.
+  /// Note, you need the & to detect inherited members.
+  template <typename T, typename... Args>
+  using has_maybe_downcast = decltype(&T::maybeDownCast);
+
   /// Returns the element at the given slice index. Supports negative indices
   /// by taking elements in inverse order. Returns a nullptr object if out
   /// of bounds.
@@ -239,8 +245,13 @@ class Sliceable {
       return {};
     }
 
-    return pybind11::cast(
-        static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
+    if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
+      return static_cast<Derived *>(this)
+          ->getRawElement(linearizeIndex(index))
+          .maybeDownCast();
+    else
+      return pybind11::cast(
+          static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
   }
 
   /// Returns a new instance of the pseudo-container restricted to the given

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 9cca7d659ec8cb3..60ce83c09f1717e 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -5,7 +5,12 @@
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+    Sequence as _Sequence,
+    Type as _Type,
+    TypeVar as _TypeVar,
+    Union as _Union,
+)
 
 __all__ = [
     "equally_sized_accessor",
@@ -123,3 +128,9 @@ def get_op_result_or_op_results(
         if len(op.results) > 0
         else op
     )
+
+
+# This is the standard way to indicate subclass/inheritance relationship
+# see the typing.Type doc string.
+_U = _TypeVar("_U", bound=_cext.ir.Value)
+SubClassValueT = _Type[_U]

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index cf4228c2a63a91b..18526ab8c3c02dc 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,7 +4,7 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
-from ._mlir_libs._mlir import register_type_caster
+from ._mlir_libs._mlir import register_type_caster, register_value_caster
 
 
 # Convenience decorator for registering user-friendly Attribute builders.

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 63dad1cc901fe2b..f7df8ba2df0ae2f 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))

diff  --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 6d1c5eab7589847..f80f2c084a0f3b8 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,9 @@
 # RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
 
 from mlir.ir import *
-import mlir.dialects.func as func
 import mlir.dialects.arith as arith
+import mlir.dialects.func as func
 
 
 def run(f):
@@ -35,14 +36,59 @@ def testFastMathFlags():
             print(r)
 
 
-# CHECK-LABEL: TEST: testArithValueBuilder
+# CHECK-LABEL: TEST: testArithValue
 @run
-def testArithValueBuilder():
+def testArithValue():
+    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
+        op = op.capitalize()
+        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+            op += "F"
+        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
+            lhs.type
+        ):
+            op += "I"
+        else:
+            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+        op = getattr(arith, f"{op}Op")
+        return op(lhs, rhs).result
+
+    @register_value_caster(F16Type.static_typeid)
+    @register_value_caster(F32Type.static_typeid)
+    @register_value_caster(F64Type.static_typeid)
+    @register_value_caster(IntegerType.static_typeid)
+    class ArithValue(Value):
+        def __init__(self, v):
+            super().__init__(v)
+
+        __add__ = partialmethod(_binary_op, op="add")
+        __sub__ = partialmethod(_binary_op, op="sub")
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, ArithValue.__name__)
+
     with Context() as ctx, Location.unknown():
         module = Module.create()
+        f16_t = F16Type.get()
         f32_t = F32Type.get()
+        f64_t = F64Type.get()
 
         with InsertionPoint(module.body):
-            a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
-            # CHECK: %cst = arith.constant 4.242000e+01 : f32
+            a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+            # CHECK: ArithValue(%cst = arith.constant 4.240
             print(a)
+
+            b = a + a
+            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+            print(b)
+
+            a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+            b = a - a
+            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+            print(b)
+
+            a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+            b = a * a
+            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+            print(b)

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 472db7e5124dbed..f313a400b73c0a5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -425,6 +425,12 @@ def __str__(self):
             # And it should be equal to the in-tree concrete type
             assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
 
+            d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
+            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
+            print(d)
+            # CHECK: TestTensorValue
+            print(repr(d))
+
 
 # CHECK-LABEL: TEST: inferReturnTypeComponents
 @run
@@ -502,19 +508,18 @@ def testCustomTypeTypeCaster():
         # CHECK: Type caster is already registered
         try:
 
+            @register_type_caster(c.typeid)
             def type_caster(pytype):
                 return test.TestIntegerRankedTensorType(pytype)
 
-            register_type_caster(c.typeid, type_caster)
         except RuntimeError as e:
             print(e)
 
-        def type_caster(pytype):
-            return RankedTensorType(pytype)
-
         # python_test dialect registers a caster for RankedTensorType in its extension (pybind) module.
         # So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
-        register_type_caster(c.typeid, type_caster, replace=True)
+        @register_type_caster(c.typeid, replace=True)
+        def type_caster(pytype):
+            return RankedTensorType(pytype)
 
         d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
         # CHECK: tensor<10x10xi5>
@@ -522,11 +527,10 @@ def type_caster(pytype):
         # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
         print("ranked tensor type", repr(d.type))
 
+        @register_type_caster(c.typeid, replace=True)
         def type_caster(pytype):
             return test.TestIntegerRankedTensorType(pytype)
 
-        register_type_caster(c.typeid, type_caster, replace=True)
-
         d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
         # CHECK: tensor<10x10xi5>
         print(d.type)

diff  --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index ddf653dcce27804..acbf463113a6d59 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -3,6 +3,7 @@
 import gc
 from mlir.ir import *
 from mlir.dialects import func
+from mlir.dialects._ods_common import SubClassValueT
 
 
 def run(f):
@@ -270,3 +271,120 @@ def testValueSetType():
 
             # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
             print(value.owner)
+
+
+# CHECK-LABEL: TEST: testValueCasters
+@run
+def testValueCasters():
+    class NOPResult(OpResult):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPResult.__name__)
+
+    class NOPValue(Value):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPValue.__name__)
+
+    class NOPBlockArg(BlockArgument):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
+
+    @register_value_caster(IntegerType.static_typeid)
+    def cast_int(v) -> SubClassValueT:
+        print("in caster", v.__class__.__name__)
+        if isinstance(v, OpResult):
+            return NOPResult(v)
+        if isinstance(v, BlockArgument):
+            return NOPBlockArg(v)
+        elif isinstance(v, Value):
+            return NOPValue(v)
+
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            values = Operation.create("custom.op1", results=[i32, i32]).results
+            # CHECK: in caster OpResult
+            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", values[0].result_number, values[0])
+            # CHECK: in caster OpResult
+            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", values[1].result_number, values[1])
+
+            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("results slice", values[:1][0].result_number, values[:1][0])
+
+            value0, value1 = values
+            # CHECK: in caster OpResult
+            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", value0.result_number, values[0])
+            # CHECK: in caster OpResult
+            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", value1.result_number, values[1])
+
+            op1 = Operation.create("custom.op2", operands=[value0, value1])
+            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
+            print(op1)
+
+            # CHECK: in caster Value
+            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("operand 0", op1.operands[0])
+            # CHECK: in caster Value
+            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("operand 1", op1.operands[1])
+
+            # CHECK: in caster BlockArgument
+            # CHECK: in caster BlockArgument
+            @func.FuncOp.from_py_func(i32, i32)
+            def reduction(arg0, arg1):
+                # CHECK: as func arg 0 NOPBlockArg
+                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
+                # CHECK: as func arg 1 NOPBlockArg
+                print("as func arg", arg1.arg_number, arg1.__class__.__name__)
+
+            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
+            print(
+                "args slice",
+                reduction.func_op.arguments[:1][0].arg_number,
+                reduction.func_op.arguments[:1][0],
+            )
+
+    try:
+
+        @register_value_caster(IntegerType.static_typeid)
+        def dont_cast_int_shouldnt_register(v):
+            ...
+
+    except RuntimeError as e:
+        # CHECK: Value caster is already registered: {{.*}}cast_int
+        print(e)
+
+    @register_value_caster(IntegerType.static_typeid, replace=True)
+    def dont_cast_int(v) -> OpResult:
+        assert isinstance(v, OpResult)
+        print("don't cast", v.result_number, v)
+        return v
+
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
+            new_value = Operation.create("custom.op1", results=[i32]).result
+            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
+            print("result", new_value.result_number, new_value)
+
+            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
+            new_value = Operation.create("custom.op2", results=[i32]).results[0]
+            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
+            print("result", new_value.result_number, new_value)

diff  --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f533082a0a147c0..aff414894cb825a 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -42,6 +42,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
             return cls(mlirPythonTestTestAttributeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
+
   mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
                      mlirPythonTestTestTypeGetTypeID)
       .def_classmethod(
@@ -50,7 +51,8 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
-  auto cls =
+
+  auto typeCls =
       mlir_type_subclass(m, "TestIntegerRankedTensorType",
                          mlirTypeIsARankedIntegerTensor,
                          py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
@@ -65,16 +67,40 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
                     encoding));
               },
               "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
-  assert(py::hasattr(cls.get_class(), "static_typeid") &&
+
+  assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
          "TestIntegerRankedTensorType has no static_typeid");
-  MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+
+  MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
+  py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+      .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID,
+                                                        "replace"_a = true)(
+          pybind11::cpp_function([typeCls](const py::object &mlirType) {
+            return typeCls.get_class()(mlirType);
+          }));
+
+  auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+                                      mlirTypeIsAPythonTestTestTensorValue)
+                      .def("is_null", [](MlirValue &self) {
+                        return mlirValueIsNull(self);
+                      });
+
   py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-      .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
-          mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
-            return cls.get_class()(mlirType);
-          }),
-          /*replace=*/true);
-  mlir_value_subclass(m, "TestTensorValue",
-                      mlirTypeIsAPythonTestTestTensorValue)
-      .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
+      .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+          mlirRankedTensorTypeID)(
+          pybind11::cpp_function([valueCls](const py::object &valueObj) {
+            py::object capsule = mlirApiObjectToCapsule(valueObj);
+            MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+            MlirType t = mlirValueGetType(v);
+            // This is hyper-specific in order to exercise/test registering a
+            // value caster from cpp (but only for a single test case; see
+            // testTensorValue in python_test.py).
+            if (mlirShapedTypeHasStaticShape(t) &&
+                mlirShapedTypeGetDimSize(t, 0) == 1 &&
+                mlirShapedTypeGetDimSize(t, 1) == 2 &&
+                mlirShapedTypeGetDimSize(t, 2) == 3)
+              return valueCls.get_class()(valueObj);
+            return valueObj;
+          }));
 }

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c8ef84721090ab9..0c0ad2cfeffdcc2 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,15 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import (
+    SubClassValueT as _SubClassValueT,
+    equally_sized_accessor as _ods_equally_sized_accessor,
+    get_default_loc_context as _ods_get_default_loc_context,
+    get_op_result_or_op_results as _get_op_result_or_op_results,
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+    segmented_accessor as _ods_segmented_accessor,
+)
 _ods_ir = _ods_cext.ir
 
 import builtins
@@ -1004,8 +1012,8 @@ static void emitValueBuilder(const Operator &op,
                       llvm::join(valueBuilderParams, ", "),
                       llvm::join(opBuilderArgs, ", "),
                       (op.getNumResults() > 1
-                           ? "_Sequence[_ods_ir.OpResult]"
-                           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+                           ? "_Sequence[_SubClassValueT]"
+                           : (op.getNumResults() > 0 ? "_SubClassValueT"
                                                      : "_ods_ir.Operation")));
 }
 


        


More information about the Mlir-commits mailing list