[Mlir-commits] [mlir] [mlir][python] value casting (PR #69644)
Maksim Levental
llvmlistbot at llvm.org
Tue Oct 31 13:30:18 PDT 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/69644
>From eb08c68dbf7392ba09ceecc7cefb8d8d5c2f15c5 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 31 Oct 2023 11:58:46 -0500
Subject: [PATCH 1/7] [mlir][python] remove various caching mechanisms
---
mlir/docs/Bindings/Python.md | 2 +-
mlir/lib/Bindings/Python/Globals.h | 20 ++---
mlir/lib/Bindings/Python/IRModule.cpp | 104 +++++++-----------------
mlir/lib/Bindings/Python/MainModule.cpp | 5 +-
4 files changed, 36 insertions(+), 95 deletions(-)
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bc2e676a878c0f4..ef984e2bed7ea3a 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -945,7 +945,7 @@ When the python bindings need to locate a wrapper module, they consult the
`dialect_search_path` and use it to find an appropriately named module. For the
main repository, this search path is hard-coded to include the `mlir.dialects`
module, which is where wrappers are emitted by the above build rule. Out of tree
-dialects and add their modules to the search path by calling:
+dialects can add their modules to the search path by calling:
```python
mlir._cext.append_dialect_search_prefix("myproject.mlir.dialects")
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 21899bdce22e810..4332954f8b6927c 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,10 +9,6 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
-#include <optional>
-#include <string>
-#include <vector>
-
#include "PybindUtils.h"
#include "mlir-c/IR.h"
@@ -21,6 +17,10 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
+#include <optional>
+#include <string>
+#include <vector>
+
namespace mlir {
namespace python {
@@ -45,10 +45,6 @@ class PyGlobals {
dialectSearchPrefixes.swap(newValues);
}
- /// Clears positive and negative caches regarding what implementations are
- /// available. Future lookups will do more expensive existence checks.
- void clearImportCache();
-
/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
/// an error on any evaluation issues.
@@ -113,16 +109,10 @@ class PyGlobals {
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
- /// Cache for map of MlirTypeID to custom type caster.
- llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
- llvm::StringSet<> loadedDialectModulesCache;
- /// Cache of operation name to external operation class object. This is
- /// maintained on lookup as a shadow of operationClassMap in order for repeat
- /// lookups of the classes to only incur the cost of one hashtable lookup.
- llvm::StringMap<pybind11::object> operationClassMapCache;
+ llvm::StringSet<> loadedDialectModules;
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index f8e22f7bb0c1ba7..598c41012b3663d 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -10,12 +10,12 @@
#include "Globals.h"
#include "PybindUtils.h"
-#include <optional>
-#include <vector>
-
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
+#include <optional>
+#include <vector>
+
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
@@ -37,7 +37,7 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
- if (loadedDialectModulesCache.contains(dialectNamespace))
+ if (loadedDialectModules.contains(dialectNamespace))
return;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
@@ -59,13 +59,13 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
- loadedDialectModulesCache.insert(dialectNamespace);
+ loadedDialectModules.insert(dialectNamespace);
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
py::function pyFunc, bool replace) {
py::object &found = attributeBuilderMap[attributeKind];
- if (found && !found.is_none() && !replace) {
+ if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind +
"' is already registered with func: " +
@@ -79,13 +79,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
pybind11::function typeCaster,
bool replace) {
pybind11::object &found = typeCasterMap[mlirTypeID];
- if (found && !found.is_none() && !replace)
- throw std::runtime_error("Type caster is already registered");
+ if (found && !replace)
+ throw std::runtime_error("Type caster is already registered with caster: " +
+ py::str(found).operator std::string());
found = std::move(typeCaster);
- const auto foundIt = typeCasterMapCache.find(mlirTypeID);
- if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
- typeCasterMapCache[mlirTypeID] = found;
- }
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -108,86 +105,51 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
.str());
}
found = std::move(pyClass);
- auto foundIt = operationClassMapCache.find(operationName);
- if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
- operationClassMapCache[operationName] = found;
- }
}
std::optional<py::function>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
- // Fast match against the class map first (common case).
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::function is defined");
+ assert(foundIt->second && "attribute builder is defined");
return foundIt->second;
}
-
- // Not found and loading did not yield a registration. Negative cache.
- attributeBuilderMap[attributeKind] = py::none();
return std::nullopt;
}
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
- {
- // Fast match against the class map first (common case).
- const auto foundIt = typeCasterMapCache.find(mlirTypeID);
- if (foundIt != typeCasterMapCache.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::function is defined");
- return foundIt->second;
- }
- }
-
- // Not found. Load the dialect namespace.
+ // Make sure dialect module is loaded.
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
-
- // Attempt to find from the canonical map and cache.
- {
- const auto foundIt = typeCasterMap.find(mlirTypeID);
- if (foundIt != typeCasterMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
- // Positive cache.
- typeCasterMapCache[mlirTypeID] = foundIt->second;
- return foundIt->second;
- }
- // Negative cache.
- typeCasterMap[mlirTypeID] = py::none();
- return std::nullopt;
+ const auto foundIt = typeCasterMap.find(mlirTypeID);
+ if (foundIt != typeCasterMap.end()) {
+ assert(foundIt->second && "type caster is defined");
+ return foundIt->second;
}
+ return std::nullopt;
}
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
+ // Make sure dialect module is loaded.
loadDialectModule(dialectNamespace);
- // Fast match against the class map first (common case).
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
+ assert(foundIt->second && "dialect class is defined");
return foundIt->second;
}
-
- // Not found and loading did not yield a registration. Negative cache.
- dialectClassMap[dialectNamespace] = py::none();
+ // Not found and loading did not yield a registration.
return std::nullopt;
}
std::optional<pybind11::object>
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
+ // Fast path: look up the class map first (hits on repeat lookups once the
+ // dialect module has already been loaded).
{
- auto foundIt = operationClassMapCache.find(operationName);
- if (foundIt != operationClassMapCache.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
+ auto foundIt = operationClassMap.find(operationName);
+ if (foundIt != operationClassMap.end()) {
+ assert(foundIt->second && "OpView is defined");
return foundIt->second;
}
}
@@ -197,25 +159,15 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
llvm::StringRef dialectNamespace = split.first;
loadDialectModule(dialectNamespace);
- // Attempt to find from the canonical map and cache.
+ // Retry the class map lookup after attempting to load the dialect module.
{
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
- // Positive cache.
- operationClassMapCache[operationName] = foundIt->second;
+ assert(foundIt->second && "OpView is defined");
return foundIt->second;
}
- // Negative cache.
- operationClassMap[operationName] = py::none();
- return std::nullopt;
}
-}
-void PyGlobals::clearImportCache() {
- loadedDialectModulesCache.clear();
- operationClassMapCache.clear();
- typeCasterMapCache.clear();
+ // Not found and loading did not yield a registration.
+ return std::nullopt;
}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index a936becf67bea75..2b6248321c1c110 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#include <tuple>
-
#include "PybindUtils.h"
#include "Globals.h"
#include "IRModule.h"
#include "Pass.h"
+#include <tuple>
+
namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
@@ -34,7 +34,6 @@ PYBIND11_MODULE(_mlir, m) {
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
- self.clearImportCache();
},
"module_name"_a)
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
>From e366033f32175a13f6b350cd4856d01583ce56d3 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 11 Oct 2023 00:28:04 -0500
Subject: [PATCH 2/7] [mlir][python] value casting
---
mlir/python/mlir/dialects/_ods_common.py | 58 +++++++++++++++-
mlir/python/mlir/ir.py | 14 ++++
mlir/test/mlir-tblgen/op-python-bindings.td | 48 ++++++-------
mlir/test/python/dialects/arith_dialect.py | 68 +++++++++++++++++--
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 17 +++--
5 files changed, 171 insertions(+), 34 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 9cca7d659ec8cb3..dd41ee63c8bf7af 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,11 +1,18 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from collections import defaultdict
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+ Callable as _Callable,
+ Sequence as _Sequence,
+ Type as _Type,
+ TypeVar as _TypeVar,
+ Union as _Union,
+)
__all__ = [
"equally_sized_accessor",
@@ -123,3 +130,52 @@ def get_op_result_or_op_results(
if len(op.results) > 0
else op
)
+
+
+U = _TypeVar("U", bound=_cext.ir.Value)
+SubClassValueT = _Type[U]
+
+ValueCasterT = _Callable[
+ [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
+]
+
+_VALUE_CASTERS: defaultdict[
+ _cext.ir.TypeID,
+ _Sequence[ValueCasterT],
+] = defaultdict(list)
+
+
+def has_value_caster(typeid: _cext.ir.TypeID):
+ if not isinstance(typeid, _cext.ir.TypeID):
+ raise ValueError(f"{typeid=} is not a TypeID")
+ if typeid in _VALUE_CASTERS:
+ return True
+ return False
+
+
+def get_value_caster(typeid: _cext.ir.TypeID):
+ if not has_value_caster(typeid):
+ raise ValueError(f"no registered caster for {typeid=}")
+ return _VALUE_CASTERS[typeid]
+
+
+def maybe_cast(
+ val: _Union[
+ _cext.ir.Value,
+ _cext.ir.OpResult,
+ _Sequence[_cext.ir.Value],
+ _Sequence[_cext.ir.OpResult],
+ _cext.ir.Operation,
+ ]
+) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
+ if isinstance(val, (tuple, list)):
+ return tuple(map(maybe_cast, val))
+
+ if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
+ return val
+
+ if has_value_caster(val.type.typeid):
+ for caster in get_value_caster(val.type.typeid):
+ if casted := caster(val):
+ return casted
+ return val
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 43553f3118a51fc..6e1f2b357f31711 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,20 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster
+from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS
+
+
+def register_value_caster(typeid: TypeID, priority: int = None):
+ def wrapper(caster: ValueCasterT):
+ if not isinstance(typeid, TypeID):
+ raise ValueError(f"{typeid=} is not a TypeID")
+ if priority is None:
+ _VALUE_CASTERS[typeid].append(caster)
+ else:
+ _VALUE_CASTERS[typeid].insert(priority, caster)
+ return caster
+
+ return wrapper
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 63dad1cc901fe2b..96b0c170dc5bb40 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
}
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
}
// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> {
}
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
}
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
}
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
}
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
}
// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
@@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">;
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
// CHECK: def empty(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
}
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
}
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
}
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
}
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
}
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
}
// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
}
// CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
}
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
}
// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
}
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
}
// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> {
}
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
}
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
}
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}
// CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 6d1c5eab7589847..180d30ff4cfb3e5 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,9 @@
# RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
from mlir.ir import *
-import mlir.dialects.func as func
import mlir.dialects.arith as arith
+from mlir.dialects._ods_common import maybe_cast
def run(f):
@@ -35,14 +36,71 @@ def testFastMathFlags():
print(r)
-# CHECK-LABEL: TEST: testArithValueBuilder
+# CHECK-LABEL: TEST: testArithValue
@run
-def testArithValueBuilder():
+def testArithValue():
+ def _binary_op(lhs, rhs, op: str):
+ op = op.capitalize()
+ if arith._is_float_type(lhs.type):
+ op += "F"
+ elif arith._is_integer_like_type(lhs.type):
+ op += "I"
+ else:
+ raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+ op = getattr(arith, f"{op}Op")
+ return maybe_cast(op(lhs, rhs).result)
+
+ @register_value_caster(F16Type.static_typeid)
+ @register_value_caster(F32Type.static_typeid)
+ @register_value_caster(F64Type.static_typeid)
+ @register_value_caster(IntegerType.static_typeid)
+ class ArithValue(Value):
+ __add__ = partialmethod(_binary_op, op="add")
+ __sub__ = partialmethod(_binary_op, op="sub")
+ __mul__ = partialmethod(_binary_op, op="mul")
+
+ def __str__(self):
+ return super().__str__().replace("Value", "ArithValue")
+
+ @register_value_caster(IntegerType.static_typeid, priority=0)
+ class ArithValue1(Value):
+ __mul__ = partialmethod(_binary_op, op="mul")
+
+ def __str__(self):
+ return super().__str__().replace("Value", "ArithValue1")
+
+ @register_value_caster(IntegerType.static_typeid, priority=0)
+ def no_op_caster(val):
+ print("no_op_caster", val)
+ return None
+
with Context() as ctx, Location.unknown():
module = Module.create()
+ f16_t = F16Type.get()
f32_t = F32Type.get()
+ f64_t = F64Type.get()
+ i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
+ a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+ b = a + a
+ # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+ print(b)
+
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
- # CHECK: %cst = arith.constant 4.242000e+01 : f32
- print(a)
+ b = a - a
+ # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+ print(b)
+
+ a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+ b = a * a
+ # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+ print(b)
+
+ # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32)
+ a = arith.constant(value=IntegerAttr.get(i32, 1))
+ b = a * a
+ # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+ # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+ print(b)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c8ef84721090ab9..170ac6b87c693d7 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,16 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import (
+ SubClassValueT as _SubClassValueT,
+ equally_sized_accessor as _ods_equally_sized_accessor,
+ get_default_loc_context as _ods_get_default_loc_context,
+ get_op_result_or_op_results as _get_op_result_or_op_results,
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ maybe_cast as _maybe_cast,
+ segmented_accessor as _ods_segmented_accessor,
+)
_ods_ir = _ods_cext.ir
import builtins
@@ -263,7 +272,7 @@ constexpr const char *regionAccessorTemplate = R"Py(
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
- return _get_op_result_or_op_results({1}({3}))
+ return _maybe_cast(_get_op_result_or_op_results({1}({3})))
)Py";
static llvm::cl::OptionCategory
@@ -1004,8 +1013,8 @@ static void emitValueBuilder(const Operator &op,
llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
(op.getNumResults() > 1
- ? "_Sequence[_ods_ir.OpResult]"
- : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+ ? "_Sequence[_SubClassValueT]"
+ : (op.getNumResults() > 0 ? "_SubClassValueT"
: "_ods_ir.Operation")));
}
>From b9c33cb349fe1568617f5394f2bb1481f0d5bee4 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 23 Oct 2023 10:07:12 -0500
Subject: [PATCH 3/7] add new line to op-python-bindings.td
---
mlir/test/mlir-tblgen/op-python-bindings.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 96b0c170dc5bb40..9844040f8a33c4b 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
\ No newline at end of file
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
>From 93184a9b3f8411e0051706b3b2ca2315f3783422 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 25 Oct 2023 14:23:52 -0500
Subject: [PATCH 4/7] WIP opresult and opoperand and blockarg casting
---
.../mlir/Bindings/Python/PybindAdaptors.h | 1 +
mlir/lib/Bindings/Python/Globals.h | 16 ++++
mlir/lib/Bindings/Python/IRCore.cpp | 52 +++++++++-
mlir/lib/Bindings/Python/IRModule.cpp | 46 +++++++++
mlir/lib/Bindings/Python/IRModule.h | 10 +-
mlir/lib/Bindings/Python/MainModule.cpp | 12 +++
mlir/lib/Bindings/Python/PybindUtils.h | 2 +-
mlir/python/mlir/dialects/_ods_common.py | 52 +---------
mlir/python/mlir/ir.py | 16 +---
mlir/test/mlir-tblgen/op-python-bindings.td | 48 +++++-----
mlir/test/python/dialects/arith_dialect.py | 28 ++----
mlir/test/python/ir/value.py | 96 +++++++++++++++++++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 3 +-
13 files changed, 260 insertions(+), 122 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 49680c8b79b135e..acc90e4ab9a22b8 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -234,6 +234,7 @@ struct type_caster<MlirValue> {
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Value")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
};
};
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 4332954f8b6927c..b06b39fb7515781 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -66,6 +66,13 @@ class PyGlobals {
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
bool replace = false);
+ /// Adds a user-friendly value caster. Raises an exception if the mapping
+ /// already exists and replace == false. This is intended to be called by
+ /// implementation code.
+ void registerValueCaster(MlirTypeID mlirTypeID,
+ pybind11::function valueCaster,
+ bool replace = false);
+
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
@@ -86,6 +93,10 @@ class PyGlobals {
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
+ /// Returns the custom value caster for MlirTypeID mlirTypeID.
+ std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
+ MlirDialect dialect);
+
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
@@ -110,6 +121,11 @@ class PyGlobals {
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
+ /// Map of MlirTypeID to custom value caster.
+ llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
+ /// Cache for map of MlirTypeID to custom value caster.
+ llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMapCache;
+
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7cfea31dbb2e80c..2c7ffda4e088032 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1899,13 +1899,26 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
}
//------------------------------------------------------------------------------
-// PyValue and subclases.
+// PyValue and subclasses.
//------------------------------------------------------------------------------
pybind11::object PyValue::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
}
+pybind11::object PyValue::maybeDownCast() {
+  MlirType type = mlirValueGetType(get());
+  MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
+  assert(!mlirTypeIDIsNull(mlirTypeID) &&
+         "mlirTypeID was expected to be non-null.");
+  std::optional<pybind11::function> valueCaster =
+      PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+  // Copy, not move: `this` may be a live Python-owned value (e.g. reached via
+  // the bound maybe-downcast method) that must remain usable afterwards.
+  py::object thisObj = py::cast(this, py::return_value_policy::copy);
+  return valueCaster ? valueCaster.value()(thisObj) : thisObj;
+}
+
PyValue PyValue::createFromCapsule(pybind11::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
@@ -2121,6 +2134,8 @@ class PyConcreteValue : public PyValue {
return DerivedTy::isaFunction(otherValue);
},
py::arg("other_value"));
+ cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](DerivedTy &self) { return self.maybeDownCast(); });
DerivedTy::bindDerived(cls);
}
@@ -2193,6 +2208,7 @@ class PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
+ using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
@@ -2202,6 +2218,13 @@ class PyBlockArgumentList
step),
operation(std::move(operation)), block(block) {}
+  pybind11::object getItem(intptr_t index) override {
+    pybind11::object element = SliceableT::getItem(index);
+    // A null handle signals out-of-bounds; return it unchanged.
+    return element ? element.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+                   : element;
+  }
+
static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyBlockArgumentList &self) {
return getValueTypes(self, self.operation->getContext());
@@ -2241,6 +2264,7 @@ class PyBlockArgumentList
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
+ using SliceableT = Sliceable<PyOpOperandList, PyValue>;
PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
@@ -2250,6 +2274,13 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
step),
operation(operation) {}
+  pybind11::object getItem(intptr_t index) override {
+    pybind11::object element = SliceableT::getItem(index);
+    // A null handle signals out-of-bounds; return it unchanged.
+    return element ? element.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+                   : element;
+  }
+
void dunderSetItem(intptr_t index, PyValue value) {
index = wrapIndex(index);
mlirOperationSetOperand(operation->get(), index, value.get());
@@ -2296,6 +2327,7 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
+ using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
@@ -2303,7 +2335,14 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
length == -1 ? mlirOperationGetNumResults(operation->get())
: length,
step),
- operation(operation) {}
+ operation(std::move(operation)) {}
+
+  pybind11::object getItem(intptr_t index) override {
+    pybind11::object element = SliceableT::getItem(index);
+    // A null handle signals out-of-bounds; return it unchanged.
+    return element ? element.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+                   : element;
+  }
static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyOpResultList &self) {
@@ -2891,8 +2930,9 @@ void mlir::python::populateIRCore(py::module &m) {
"single result)")
.str());
}
- return PyOpResult(operation.getRef(),
- mlirOperationGetResult(operation, 0));
+ PyOpResult result = PyOpResult(
+ operation.getRef(), mlirOperationGetResult(operation, 0));
+ return result.maybeDownCast();
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
@@ -3566,7 +3606,9 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyValue &self, PyValue &with) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
- kValueReplaceAllUsesWithDocstring);
+ kValueReplaceAllUsesWithDocstring)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) { return self.maybeDownCast(); });
PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 598c41012b3663d..1fb7494ef6d2958 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -85,6 +85,19 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
found = std::move(typeCaster);
}
+void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
+                                    pybind11::function valueCaster,
+                                    bool replace) {
+  pybind11::object &slot = valueCasterMap[mlirTypeID];
+  if (!replace && slot && !slot.is_none())
+    throw std::runtime_error("Value caster is already registered");
+  slot = std::move(valueCaster);
+  // Refresh an existing positive cache entry so lookups see the replacement.
+  auto cached = valueCasterMapCache.find(mlirTypeID);
+  if (cached != valueCasterMapCache.end() && !cached->second.is_none())
+    cached->second = slot;
+}
+
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
@@ -129,6 +142,39 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}
+std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
+                                                         MlirDialect dialect) {
+  {
+    // Fast path: a caster resolved by an earlier lookup lives in the cache.
+    const auto foundIt = valueCasterMapCache.find(mlirTypeID);
+    if (foundIt != valueCasterMapCache.end()) {
+      if (foundIt->second.is_none())
+        return std::nullopt;
+      assert(foundIt->second && "py::function is defined");
+      return foundIt->second;
+    }
+  }
+
+  // Cache miss: load the dialect's Python module (it may register casters).
+  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+
+  // Consult the canonical registry, promoting hits into the cache.
+  {
+    const auto foundIt = valueCasterMap.find(mlirTypeID);
+    if (foundIt != valueCasterMap.end()) {
+      if (foundIt->second.is_none())
+        return std::nullopt;
+      assert(foundIt->second && "py::object is defined");
+      // Remember the hit so later lookups take the fast path above.
+      valueCasterMapCache[mlirTypeID] = foundIt->second;
+      return foundIt->second;
+    }
+    // Miss: store None in the canonical map; the cache is left untouched.
+    valueCasterMap[mlirTypeID] = py::none();
+    return std::nullopt;
+  }
+}
+
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 01ee4975d0e9a91..b95c4578fbc2220 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -761,7 +761,7 @@ class PyRegion {
/// Wrapper around an MlirAsmState.
class PyAsmState {
- public:
+public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
// The OpPrintingFlags are not exposed Python side, create locally and
@@ -780,16 +780,14 @@ class PyAsmState {
state =
mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
}
- ~PyAsmState() {
- mlirOpPrintingFlagsDestroy(flags);
- }
+ ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
PyAsmState(const PyAsmState &other) = delete;
MlirAsmState get() { return state; }
- private:
+private:
MlirAsmState state;
MlirOpPrintingFlags flags;
};
@@ -1124,6 +1122,8 @@ class PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
pybind11::object getCapsule();
+ virtual pybind11::object maybeDownCast();
+
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
static PyValue createFromCapsule(pybind11::object capsule);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 2b6248321c1c110..8c64437aad41567 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -87,6 +87,18 @@ PYBIND11_MODULE(_mlir, m) {
},
"typeid"_a, "type_caster"_a, "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
+ m.def(
+ "register_value_caster",
+ [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
+ return py::cpp_function(
+ [mlirTypeID, replace](py::object valueCaster) -> py::object {
+ PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
+ replace);
+ return valueCaster;
+ });
+ },
+ "typeid"_a, "replace"_a = false,
+ "Register a value caster for casting MLIR values to custom user values.");
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 2a8da20bee0495d..efb7b713f80a40c 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -231,7 +231,7 @@ class Sliceable {
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
- pybind11::object getItem(intptr_t index) {
+ virtual pybind11::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index dd41ee63c8bf7af..fa73c197c17faf6 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -7,7 +7,6 @@
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import (
- Callable as _Callable,
Sequence as _Sequence,
Type as _Type,
TypeVar as _TypeVar,
@@ -132,50 +131,7 @@ def get_op_result_or_op_results(
)
-U = _TypeVar("U", bound=_cext.ir.Value)
-SubClassValueT = _Type[U]
-
-ValueCasterT = _Callable[
- [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
-]
-
-_VALUE_CASTERS: defaultdict[
- _cext.ir.TypeID,
- _Sequence[ValueCasterT],
-] = defaultdict(list)
-
-
-def has_value_caster(typeid: _cext.ir.TypeID):
- if not isinstance(typeid, _cext.ir.TypeID):
- raise ValueError(f"{typeid=} is not a TypeID")
- if typeid in _VALUE_CASTERS:
- return True
- return False
-
-
-def get_value_caster(typeid: _cext.ir.TypeID):
- if not has_value_caster(typeid):
- raise ValueError(f"no registered caster for {typeid=}")
- return _VALUE_CASTERS[typeid]
-
-
-def maybe_cast(
- val: _Union[
- _cext.ir.Value,
- _cext.ir.OpResult,
- _Sequence[_cext.ir.Value],
- _Sequence[_cext.ir.OpResult],
- _cext.ir.Operation,
- ]
-) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
- if isinstance(val, (tuple, list)):
- return tuple(map(maybe_cast, val))
-
- if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
- return val
-
- if has_value_caster(val.type.typeid):
- for caster in get_value_caster(val.type.typeid):
- if casted := caster(val):
- return casted
- return val
+# This is the standard way to indicate subclass/inheritance relationship
+# see the typing.Type doc string.
+_U = _TypeVar("_U", bound=_cext.ir.Value)
+SubClassValueT = _Type[_U]
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6e1f2b357f31711..eede64de674e22b 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,21 +4,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
-from ._mlir_libs._mlir import register_type_caster
-from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS
-
-
-def register_value_caster(typeid: TypeID, priority: int = None):
- def wrapper(caster: ValueCasterT):
- if not isinstance(typeid, TypeID):
- raise ValueError(f"{typeid=} is not a TypeID")
- if priority is None:
- _VALUE_CASTERS[typeid].append(caster)
- else:
- _VALUE_CASTERS[typeid].insert(priority, caster)
- return caster
-
- return wrapper
+from ._mlir_libs._mlir import register_type_caster, register_value_caster
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9844040f8a33c4b..f7df8ba2df0ae2f 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
}
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
}
// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> {
}
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
}
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
}
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
}
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
}
// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
@@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">;
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
// CHECK: def empty(*, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
}
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
}
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
}
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
}
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
}
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
}
// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
}
// CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
}
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
}
// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
}
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
}
// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> {
}
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
}
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
}
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}
// CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
+// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 180d30ff4cfb3e5..39c3d5799a6563a 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -3,7 +3,7 @@
from mlir.ir import *
import mlir.dialects.arith as arith
-from mlir.dialects._ods_common import maybe_cast
+import mlir.dialects.func as func
def run(f):
@@ -49,31 +49,22 @@ def _binary_op(lhs, rhs, op: str):
raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
op = getattr(arith, f"{op}Op")
- return maybe_cast(op(lhs, rhs).result)
+ return op(lhs, rhs).result
@register_value_caster(F16Type.static_typeid)
@register_value_caster(F32Type.static_typeid)
@register_value_caster(F64Type.static_typeid)
@register_value_caster(IntegerType.static_typeid)
class ArithValue(Value):
+ def __init__(self, v):
+ super().__init__(v)
+
__add__ = partialmethod(_binary_op, op="add")
__sub__ = partialmethod(_binary_op, op="sub")
__mul__ = partialmethod(_binary_op, op="mul")
def __str__(self):
- return super().__str__().replace("Value", "ArithValue")
-
- @register_value_caster(IntegerType.static_typeid, priority=0)
- class ArithValue1(Value):
- __mul__ = partialmethod(_binary_op, op="mul")
-
- def __str__(self):
- return super().__str__().replace("Value", "ArithValue1")
-
- @register_value_caster(IntegerType.static_typeid, priority=0)
- def no_op_caster(val):
- print("no_op_caster", val)
- return None
+ return super().__str__().replace(Value.__name__, ArithValue.__name__)
with Context() as ctx, Location.unknown():
module = Module.create()
@@ -97,10 +88,3 @@ def no_op_caster(val):
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
-
- # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32)
- a = arith.constant(value=IntegerAttr.get(i32, 1))
- b = a * a
- # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32)
- # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32)
- print(b)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index ddf653dcce27804..1c3e1a6ae9654fe 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -270,3 +270,99 @@ def testValueSetType():
# CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
print(value.owner)
+
+
+# CHECK-LABEL: TEST: testValueCasters
+@run
+def testValueCasters():
+ class NOPResult(OpResult):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPResult.__name__)
+
+ class NOPValue(Value):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPValue.__name__)
+
+ class NOPBlockArg(BlockArgument):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
+
+ @register_value_caster(IntegerType.static_typeid)
+ def cast_int(v):
+ print("in caster", v.__class__.__name__)
+ if isinstance(v, OpResult):
+ return NOPResult(v)
+ if isinstance(v, BlockArgument):
+ return NOPBlockArg(v)
+ elif isinstance(v, Value):
+ return NOPValue(v)
+
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ values = Operation.create("custom.op1", results=[i32, i32]).results
+ # CHECK: in caster OpResult
+ # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", values[0].result_number, values[0])
+ # CHECK: in caster OpResult
+ # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", values[1].result_number, values[0])
+
+ value0, value1 = values
+ # CHECK: in caster OpResult
+ # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", value0.result_number, values[0])
+ # CHECK: in caster OpResult
+ # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", value1.result_number, values[0])
+
+ op1 = Operation.create("custom.op2", operands=[value0, value1])
+ # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
+ print(op1)
+
+ # CHECK: in caster Value
+ # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("operand 0", op1.operands[0])
+ # CHECK: in caster Value
+ # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("operand 1", op1.operands[1])
+
+ # CHECK: in caster BlockArgument
+ # CHECK: in caster BlockArgument
+ @func.FuncOp.from_py_func(i32, i32)
+ def reduction(arg0, arg1):
+ # CHECK: as func arg 0 NOPBlockArg
+ print("as func arg", arg0.arg_number, arg0.__class__.__name__)
+ # CHECK: as func arg 1 NOPBlockArg
+ print("as func arg", arg1.arg_number, arg1.__class__.__name__)
+
+ @register_value_caster(IntegerType.static_typeid, replace=True)
+ def dont_cast_int(v):
+ print("don't cast", v.result_number, v)
+ return v
+
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
+ new_value = Operation.create("custom.op1", results=[i32]).result
+ # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
+ print("result", new_value.result_number, new_value)
+
+ # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
+ new_value = Operation.create("custom.op2", results=[i32]).results[0]
+ # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
+ print("result", new_value.result_number, new_value)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 170ac6b87c693d7..0c0ad2cfeffdcc2 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -37,7 +37,6 @@ from ._ods_common import (
get_op_result_or_op_results as _get_op_result_or_op_results,
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
- maybe_cast as _maybe_cast,
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
@@ -272,7 +271,7 @@ constexpr const char *regionAccessorTemplate = R"Py(
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
- return _maybe_cast(_get_op_result_or_op_results({1}({3})))
+ return _get_op_result_or_op_results({1}({3}))
)Py";
static llvm::cl::OptionCategory
>From 41fccad77f95f524e9e92ce90c9bbbc39f263ccf Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 30 Oct 2023 13:34:21 -0500
Subject: [PATCH 5/7] done with opresult, blockarg casting
---
mlir/include/mlir-c/Bindings/Python/Interop.h | 18 ++++++++-
mlir/lib/Bindings/Python/IRCore.cpp | 12 +++---
mlir/lib/Bindings/Python/IRModule.cpp | 3 +-
mlir/lib/Bindings/Python/MainModule.cpp | 2 +-
mlir/python/mlir/dialects/_ods_common.py | 1 -
mlir/test/python/dialects/arith_dialect.py | 3 +-
mlir/test/python/dialects/python_test.py | 6 +++
mlir/test/python/ir/value.py | 29 ++++++++++++--
mlir/test/python/lib/PythonTestModule.cpp | 40 +++++++++++++++----
9 files changed, 90 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index f79c10cb9383829..9b026a6b922de47 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -121,10 +121,26 @@
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
* bool replace)
* where replace indicates the typeCaster should replace any existing registered
- * type casters (such as those for upstream ConcreteTypes).
+ * type casters (such as those for upstream ConcreteTypes). The interface of the
+ * typeCaster is:
+ * def type_caster(ir.Type) -> SubClassTypeT
+ * where SubClassTypeT indicates the result should be a subclass (inherit from)
+ * ir.Type.
*/
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
+/** Attribute on main C extension module (_mlir) that corresponds to the
+ * value caster registration binding. The signature of the function is:
+ * def register_value_caster(MlirTypeID mlirTypeID, bool replace,
+ * py::function valueCaster)
+ * where replace indicates the valueCaster should replace any existing
+ * registered value casters. The interface of the valueCaster is:
+ * def value_caster(ir.Value) -> SubClassValueT
+ * where SubClassValueT indicates the result should be a subclass (inherit from)
+ * ir.Value.
+ */
+#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"
+
/// Gets a void* from a wrapped struct. Needed because const cast is different
/// between C/C++.
#ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2c7ffda4e088032..53eb75f810c1845 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1913,10 +1913,10 @@ pybind11::object PyValue::maybeDownCast() {
"mlirTypeID was expected to be non-null.");
std::optional<pybind11::function> valueCaster =
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
- py::object this_ = py::cast(this, py::return_value_policy::move);
+ py::object thisObj = py::cast(this, py::return_value_policy::move);
if (!valueCaster)
- return this_;
- return valueCaster.value()(this_);
+ return thisObj;
+ return valueCaster.value()(thisObj);
}
PyValue PyValue::createFromCapsule(pybind11::object capsule) {
@@ -2930,9 +2930,9 @@ void mlir::python::populateIRCore(py::module &m) {
"single result)")
.str());
}
- PyOpResult result = PyOpResult(
- operation.getRef(), mlirOperationGetResult(operation, 0));
- return result.maybeDownCast();
+ return PyOpResult(operation.getRef(),
+ mlirOperationGetResult(operation, 0))
+ .maybeDownCast();
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 1fb7494ef6d2958..45f2df9e64190e8 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -90,7 +90,8 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
bool replace) {
pybind11::object &found = valueCasterMap[mlirTypeID];
if (found && !found.is_none() && !replace)
- throw std::runtime_error("Value caster is already registered");
+ throw std::runtime_error("Value caster is already registered: " +
+ py::repr(found).cast<std::string>());
found = std::move(valueCaster);
const auto foundIt = valueCasterMapCache.find(mlirTypeID);
if (foundIt != valueCasterMapCache.end() && !foundIt->second.is_none()) {
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 8c64437aad41567..2d322829c6d1030 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -88,7 +88,7 @@ PYBIND11_MODULE(_mlir, m) {
"typeid"_a, "type_caster"_a, "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
m.def(
- "register_value_caster",
+ MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
return py::cpp_function(
[mlirTypeID, replace](py::object valueCaster) -> py::object {
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index fa73c197c17faf6..60ce83c09f1717e 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,7 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from collections import defaultdict
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 39c3d5799a6563a..c8d21dfb62ed557 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -39,7 +39,7 @@ def testFastMathFlags():
# CHECK-LABEL: TEST: testArithValue
@run
def testArithValue():
- def _binary_op(lhs, rhs, op: str):
+ def _binary_op(lhs, rhs, op: str) -> "ArithValue":
op = op.capitalize()
if arith._is_float_type(lhs.type):
op += "F"
@@ -71,7 +71,6 @@ def __str__(self):
f16_t = F16Type.get()
f32_t = F32Type.get()
f64_t = F64Type.get()
- i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 3d4cd087fbfed8f..ef5ef615e63c13c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -426,6 +426,12 @@ def __str__(self):
# And it should be equal to the in-tree concrete type
assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
+ d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
+ # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
+ print(d)
+ # CHECK: <importlib._bootstrap.TestTensorValue object at
+ print(repr(d))
+
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 1c3e1a6ae9654fe..f665f05c9f1a882 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
from mlir.dialects import func
+from mlir.dialects._ods_common import SubClassValueT
def run(f):
@@ -297,7 +298,7 @@ def __str__(self):
return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
@register_value_caster(IntegerType.static_typeid)
- def cast_int(v):
+ def cast_int(v) -> SubClassValueT:
print("in caster", v.__class__.__name__)
if isinstance(v, OpResult):
return NOPResult(v)
@@ -318,7 +319,10 @@ def cast_int(v):
print("result", values[0].result_number, values[0])
# CHECK: in caster OpResult
# CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
- print("result", values[1].result_number, values[0])
+ print("result", values[1].result_number, values[1])
+
+ # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("results slice", values[:1][0].result_number, values[:1][0])
value0, value1 = values
# CHECK: in caster OpResult
@@ -326,7 +330,7 @@ def cast_int(v):
print("result", value0.result_number, values[0])
# CHECK: in caster OpResult
# CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
- print("result", value1.result_number, values[0])
+ print("result", value1.result_number, values[1])
op1 = Operation.create("custom.op2", operands=[value0, value1])
# CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
@@ -348,8 +352,25 @@ def reduction(arg0, arg1):
# CHECK: as func arg 1 NOPBlockArg
print("as func arg", arg1.arg_number, arg1.__class__.__name__)
+ # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
+ print(
+ "args slice",
+ reduction.func_op.arguments[:1][0].arg_number,
+ reduction.func_op.arguments[:1][0],
+ )
+
+ try:
+
+ @register_value_caster(IntegerType.static_typeid)
+ def dont_cast_int_shouldnt_register(v):
+ ...
+
+ except RuntimeError as e:
+ # CHECK: Value caster is already registered: <function testValueCasters.<locals>.cast_int at
+ print(e)
+
@register_value_caster(IntegerType.static_typeid, replace=True)
- def dont_cast_int(v):
+ def dont_cast_int(v) -> Value:
print("don't cast", v.result_number, v)
return v
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f533082a0a147c0..1e584343d0f0a85 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -42,6 +42,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestAttributeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
+
mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
mlirPythonTestTestTypeGetTypeID)
.def_classmethod(
@@ -50,7 +51,8 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
- auto cls =
+
+ auto typeCls =
mlir_type_subclass(m, "TestIntegerRankedTensorType",
mlirTypeIsARankedIntegerTensor,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
@@ -65,16 +67,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
encoding));
},
"cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
- assert(py::hasattr(cls.get_class(), "static_typeid") &&
+
+ assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
"TestIntegerRankedTensorType has no static_typeid");
- MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+
+ MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
- return cls.get_class()(mlirType);
+ mlirRankedTensorTypeID,
+ pybind11::cpp_function([typeCls](const py::object &mlirType) {
+ return typeCls.get_class()(mlirType);
}),
/*replace=*/true);
- mlir_value_subclass(m, "TestTensorValue",
- mlirTypeIsAPythonTestTestTensorValue)
- .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
+
+ auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+ mlirTypeIsAPythonTestTestTensorValue)
+ .def("is_null", [](MlirValue &self) {
+ return mlirValueIsNull(self);
+ });
+
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+ mlirRankedTensorTypeID)(
+ pybind11::cpp_function([valueCls](const py::object &valueObj) {
+ py::object capsule = mlirApiObjectToCapsule(valueObj);
+ MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+ MlirType t = mlirValueGetType(v);
+ if (mlirShapedTypeHasStaticShape(t) &&
+ mlirShapedTypeGetDimSize(t, 0) == 1 &&
+ mlirShapedTypeGetDimSize(t, 1) == 2 &&
+ mlirShapedTypeGetDimSize(t, 2) == 3)
+ return valueCls.get_class()(valueObj);
+ return valueObj;
+ }));
}
>From 1e1b53ee97674d64e190739dd314a61988591da7 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 31 Oct 2023 14:00:16 -0500
Subject: [PATCH 6/7] remove valuecastercache
---
mlir/lib/Bindings/Python/Globals.h | 4 ---
mlir/lib/Bindings/Python/IRModule.cpp | 38 +++++----------------------
2 files changed, 6 insertions(+), 36 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index b06b39fb7515781..35555d614d811cd 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -120,12 +120,8 @@ class PyGlobals {
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
-
/// Map of MlirTypeID to custom value caster.
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
- /// Cache for map of MlirTypeID to custom value caster.
- llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMapCache;
-
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 45f2df9e64190e8..c1a632624ba5369 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -89,14 +89,10 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
bool replace) {
pybind11::object &found = valueCasterMap[mlirTypeID];
- if (found && !found.is_none() && !replace)
+ if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
py::repr(found).cast<std::string>());
found = std::move(valueCaster);
- const auto foundIt = valueCasterMapCache.find(mlirTypeID);
- if (foundIt != valueCasterMapCache.end() && !foundIt->second.is_none()) {
- valueCasterMapCache[mlirTypeID] = found;
- }
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -145,35 +141,13 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
- {
- // Fast match against the value caster map first (common case).
- const auto foundIt = valueCasterMapCache.find(mlirTypeID);
- if (foundIt != valueCasterMapCache.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::function is defined");
- return foundIt->second;
- }
- }
-
- // Not found. Load the dialect namespace.
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
-
- // Attempt to find from the canonical map and cache.
- {
- const auto foundIt = valueCasterMap.find(mlirTypeID);
- if (foundIt != valueCasterMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
- // Positive cache.
- valueCasterMapCache[mlirTypeID] = foundIt->second;
- return foundIt->second;
- }
- // Negative cache.
- valueCasterMap[mlirTypeID] = py::none();
- return std::nullopt;
+ const auto foundIt = valueCasterMap.find(mlirTypeID);
+ if (foundIt != valueCasterMap.end()) {
+ assert(foundIt->second && "value caster is defined");
+ return foundIt->second;
}
+ return std::nullopt;
}
std::optional<py::object>
>From f5f98eb72b19c894f2e108c08b0cc59e6058c43f Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 31 Oct 2023 15:30:03 -0500
Subject: [PATCH 7/7] use detection idiom inside of `getItem` instead of
virtual member
---
mlir/lib/Bindings/Python/IRCore.cpp | 21 ---------------------
mlir/lib/Bindings/Python/IRModule.h | 3 ++-
mlir/lib/Bindings/Python/PybindUtils.h | 17 ++++++++++++++---
mlir/test/python/dialects/arith_dialect.py | 3 +++
mlir/test/python/ir/value.py | 3 ++-
5 files changed, 21 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 53eb75f810c1845..5f6b7d380bc023b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2218,13 +2218,6 @@ class PyBlockArgumentList
step),
operation(std::move(operation)), block(block) {}
- pybind11::object getItem(intptr_t index) override {
- auto item = this->SliceableT::getItem(index);
- if (item.ptr() != nullptr)
- return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
- return item;
- }
-
static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyBlockArgumentList &self) {
return getValueTypes(self, self.operation->getContext());
@@ -2274,13 +2267,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
step),
operation(operation) {}
- pybind11::object getItem(intptr_t index) override {
- auto item = this->SliceableT::getItem(index);
- if (item.ptr() != nullptr)
- return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
- return item;
- }
-
void dunderSetItem(intptr_t index, PyValue value) {
index = wrapIndex(index);
mlirOperationSetOperand(operation->get(), index, value.get());
@@ -2337,13 +2323,6 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
step),
operation(std::move(operation)) {}
- pybind11::object getItem(intptr_t index) override {
- auto item = this->SliceableT::getItem(index);
- if (item.ptr() != nullptr)
- return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
- return item;
- }
-
static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyOpResultList &self) {
return getValueTypes(self, self.operation->getContext());
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b95c4578fbc2220..c6442ce8fefde4a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -1110,6 +1110,7 @@ class PyConcreteAttribute : public BaseTy {
/// bindings so such operation always exists).
class PyValue {
public:
+ virtual ~PyValue() = default;
PyValue(PyOperationRef parentOperation, MlirValue value)
: parentOperation(std::move(parentOperation)), value(value) {}
operator MlirValue() const { return value; }
@@ -1122,7 +1123,7 @@ class PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
pybind11::object getCapsule();
- virtual pybind11::object maybeDownCast();
+ pybind11::object maybeDownCast();
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index efb7b713f80a40c..38462ac8ba6db9c 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -10,6 +10,7 @@
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#include "mlir-c/Support.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
@@ -228,10 +229,15 @@ class Sliceable {
return linearIndex;
}
+ /// Trait to check if T provides a `maybeDownCast` method.
+ /// Note, you need the & to detect inherited members.
+ template <typename T, typename... Args>
+ using has_maybe_downcast = decltype(&T::maybeDownCast);
+
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
- virtual pybind11::object getItem(intptr_t index) {
+ pybind11::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
@@ -239,8 +245,13 @@ class Sliceable {
return {};
}
- return pybind11::cast(
- static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
+ if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
+ return static_cast<Derived *>(this)
+ ->getRawElement(linearizeIndex(index))
+ .maybeDownCast();
+ else
+ return pybind11::cast(
+ static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}
/// Returns a new instance of the pseudo-container restricted to the given
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c8d21dfb62ed557..25a258f3e368814 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -74,6 +74,9 @@ def __str__(self):
with InsertionPoint(module.body):
a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+ # CHECK: ArithValue(%cst = arith.constant 4.240
+ print(a)
+
b = a + a
# CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
print(b)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index f665f05c9f1a882..3723c00785f039c 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -370,7 +370,8 @@ def dont_cast_int_shouldnt_register(v):
print(e)
@register_value_caster(IntegerType.static_typeid, replace=True)
- def dont_cast_int(v) -> Value:
+ def dont_cast_int(v) -> OpResult:
+ assert isinstance(v, OpResult)
print("don't cast", v.result_number, v)
return v
More information about the Mlir-commits
mailing list