[Mlir-commits] [mlir] [mlir][python] fix `replace=True` for `register_operation` and `register_type_caster` (PR #70264)
llvmlistbot at llvm.org
Wed Oct 25 17:35:44 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
It turns out that none of the `replace=True` paths actually work, because of the map caches (the exception is `register_attribute_builder(replace=True)`, which doesn't use such a cache). This was hidden by a series of unfortunate events:
1. The `register_type_caster` failure was hidden because the test replaced the `TestIntegerRankedTensorType` caster with itself (d'oh).
2. The `register_operation` failure was hidden by the order of events in the lifecycle of typical extension import/use. Since extensions are loaded/registered almost immediately after the generated builders are registered, there is no opportunity for `operationClassMapCache` to be populated (through, e.g., `module.body.operations[2]` or `module.body.operations[2].opview`). Of course, as soon as you actually "late-bind/late-register" the extension, you see that it doesn't replace the stale entry in `operationClassMapCache` (sketched below).
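To make that second failure mode concrete, here is a rough, untested sketch of the late-registration scenario. It mirrors the new test in `operation.py` below; the parsed IR and the trivial subclass are made up for illustration:
```python
from mlir.ir import Context, Location, Module
from mlir.dialects import arith
from mlir.dialects._ods_common import _cext

with Context(), Location.unknown():
    module = Module.parse("%0 = arith.constant 0 : i32")

    # This lookup populates operationClassMapCache for "arith.constant"
    # with the generated arith.ConstantOp class.
    print(repr(module.body.operations[0]))

    # Late registration: before this fix only operationClassMap was updated,
    # so lookups kept returning the stale cached class.
    @_cext.register_operation(arith._Dialect, replace=True)
    class ConstantOp(arith.ConstantOp):
        pass

    # With the fix, this resolves to the replacement class.
    print(repr(module.body.operations[0]))
```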
I'll take this opportunity to propose that we ditch the caches altogether. I've been cargo-culting them, but I really don't understand how they work. There's this comment above `operationClassMapCache`:
```cpp
/// Cache of operation name to external operation class object. This is
/// maintained on lookup as a shadow of operationClassMap in order for repeat
/// lookups of the classes to only incur the cost of one hashtable lookup.
llvm::StringMap<pybind11::object> operationClassMapCache;
```
But I don't understand how that's true, given that the canonical container `operationClassMap` is already a hash map:
```cpp
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
```
Maybe that wasn't always the case? Anyway, things work now, but the cache seems like an unnecessary layer of complexity for not much gain. But maybe I'm wrong.
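For reference, here is a rough Python analogue of the shadow-cache pattern described above, with plain dicts standing in for `llvm::StringMap`; the names and structure are simplified for illustration and are not the actual bindings code:
```python
# Canonical registrations and the lazily populated shadow cache.
operation_class_map: dict[str, type] = {}
operation_class_map_cache: dict[str, type] = {}


def lookup_operation_class(operation_name: str):
    """Resolve an operation name to its registered OpView class."""
    cls = operation_class_map_cache.get(operation_name)
    if cls is not None:
        return cls
    cls = operation_class_map.get(operation_name)
    if cls is not None:
        # Shadow the canonical entry so later lookups hit the cache first.
        operation_class_map_cache[operation_name] = cls
    return cls


def register_operation_impl(operation_name: str, py_class: type, replace: bool = False):
    if operation_name in operation_class_map and not replace:
        raise RuntimeError(f"Operation '{operation_name}' is already registered.")
    operation_class_map[operation_name] = py_class
    # Without this step, a class that was already looked up (and therefore
    # cached) keeps being returned even after a replace=True registration;
    # that is exactly the bug this PR fixes.
    if operation_name in operation_class_map_cache:
        operation_class_map_cache[operation_name] = py_class
```
Since both containers are hash maps, a repeat lookup only saves a single hash probe, which is why the cache reads like complexity for little gain.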
---
Full diff: https://github.com/llvm/llvm-project/pull/70264.diff
3 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRModule.cpp (+8)
- (modified) mlir/test/python/dialects/python_test.py (+11)
- (modified) mlir/test/python/ir/operation.py (+24)
```diff
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index a1c8ab7a09ce155..f8e22f7bb0c1ba7 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -82,6 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
   if (found && !found.is_none() && !replace)
     throw std::runtime_error("Type caster is already registered");
   found = std::move(typeCaster);
+  const auto foundIt = typeCasterMapCache.find(mlirTypeID);
+  if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
+    typeCasterMapCache[mlirTypeID] = found;
+  }
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -104,6 +108,10 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
                                  .str());
   }
   found = std::move(pyClass);
+  auto foundIt = operationClassMapCache.find(operationName);
+  if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
+    operationClassMapCache[operationName] = found;
+  }
 }
 
 std::optional<py::function>
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 651e6554eebe8bd..a70b6fd5e5e4d84 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -510,6 +510,17 @@ def type_caster(pytype):
     except RuntimeError as e:
         print(e)
 
+    def type_caster(pytype):
+        return RankedTensorType(pytype)
+
+    register_type_caster(c.typeid, type_caster, replace=True)
+
+    d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
+    # CHECK: tensor<10x10xi5>
+    print(d.type)
+    # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
+    print("ranked tensor type", repr(d.type))
+
     def type_caster(pytype):
         return test.TestIntegerRankedTensorType(pytype)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 129b7fa744e4721..5ded4814e54bf66 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -5,6 +5,8 @@
 import itertools
 
 from mlir.ir import *
 from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.dialects._ods_common import _cext
 
 def run(f):
@@ -646,6 +648,7 @@ def testKnownOpView():
             %1 = "custom.f32"() : () -> f32
             %2 = "custom.f32"() : () -> f32
             %3 = arith.addf %1, %2 : f32
+            %4 = arith.constant 0 : i32
         """
         )
         print(module)
@@ -668,6 +671,27 @@ def testKnownOpView():
         # CHECK: OpView object
         print(repr(custom))
 
+        # constant should map to an extension OpView class in the arithmetic dialect.
+        constant = module.body.operations[3]
+        # CHECK: <mlir.dialects.arith.ConstantOp object
+        print(repr(constant))
+        # CHECK: literal value 0
+        print("literal value", constant.literal_value)
+
+        @_cext.register_operation(arith._Dialect, replace=True)
+        class ConstantOp(arith.ConstantOp):
+            def __init__(self, result, value, *, loc=None, ip=None):
+                if isinstance(value, int):
+                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+                elif isinstance(value, float):
+                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+                else:
+                    super().__init__(value, loc=loc, ip=ip)
+
+        constant = module.body.operations[3]
+        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
+        print(repr(constant))
+
 
 # CHECK-LABEL: TEST: testSingleResultProperty
 @run
```
https://github.com/llvm/llvm-project/pull/70264