[Mlir-commits] [mlir] b0e00ca - [mlir][python] fix `replace=True` for `register_operation` and `register_type_caster` (#70264)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 30 18:22:32 PDT 2023


Author: Maksim Levental
Date: 2023-10-30T20:22:27-05:00
New Revision: b0e00ca6a605b88e83129c8c6be4e177f93cbfea

URL: https://github.com/llvm/llvm-project/commit/b0e00ca6a605b88e83129c8c6be4e177f93cbfea
DIFF: https://github.com/llvm/llvm-project/commit/b0e00ca6a605b88e83129c8c6be4e177f93cbfea.diff

LOG: [mlir][python] fix `replace=True` for `register_operation` and `register_type_caster` (#70264)


So it turns out that none of the `replace=True` options actually work,
because of the map caches (except for
`register_attribute_builder(replace=True)`, which doesn't use such a
cache). This was hidden by a series of unfortunate events:

1. `register_type_caster` failure was hidden because it was the same
`TestIntegerRankedTensorType` being replaced with itself (d'oh).
2. `register_operation` failure was hidden behind the "order of events"
in the lifecycle of typical extension import/use. Since extensions are
loaded/registered almost immediately after generated builders are
registered, there is no opportunity for the `operationClassMapCache` to
be populated (through e.g., `module.body.operations[2]` or
`module.body.operations[2].opview` or something). Of course, as soon as
you actually "late-bind/late-register" the extension, you see that it's
not successfully replacing the stale one in `operationClassMapCache`.
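
To make the failure mode concrete, here is a minimal pure-Python sketch of a
canonical map shadowed by a lookup cache. Everything in it (`ToyRegistry`,
`refresh_cache_on_register`, `late_replace`, the two placeholder classes) is
hypothetical and only models the behavior described above; it is not the
actual `PyGlobals` code, and the toy caches only successful lookups.

```python
class ToyRegistry:
    """Toy stand-in for a canonical map plus a lookup cache (not PyGlobals)."""

    def __init__(self, refresh_cache_on_register):
        self.class_map = {}        # canonical: operation name -> class
        self.class_map_cache = {}  # shadow of class_map, filled in on lookup
        self.refresh_cache_on_register = refresh_cache_on_register

    def register(self, name, cls, replace=False):
        if name in self.class_map and not replace:
            raise RuntimeError(f"Operation '{name}' is already registered.")
        self.class_map[name] = cls
        # Sketch of the fix: if a lookup already cached a class for this
        # name, overwrite the cached entry as well.
        if self.refresh_cache_on_register and name in self.class_map_cache:
            self.class_map_cache[name] = cls

    def lookup(self, name):
        # Serve from the cache when possible, otherwise fall back to the
        # canonical map and remember a successful result.
        if name in self.class_map_cache:
            return self.class_map_cache[name]
        cls = self.class_map.get(name)
        if cls is not None:
            self.class_map_cache[name] = cls
        return cls


class GeneratedConstantOp:  # placeholder for a generated builder class
    pass


class ExtendedConstantOp(GeneratedConstantOp):  # placeholder for an extension
    pass


def late_replace(registry):
    """Register, look up (which populates the cache), then replace late."""
    registry.register("arith.constant", GeneratedConstantOp)
    registry.lookup("arith.constant")
    registry.register("arith.constant", ExtendedConstantOp, replace=True)
    return registry.lookup("arith.constant")


# Without the cache refresh the stale class is still returned.
print(late_replace(ToyRegistry(refresh_cache_on_register=False)).__name__)
# With the cache refresh the replacement is visible.
print(late_replace(ToyRegistry(refresh_cache_on_register=True)).__name__)
```

If the `lookup` call between the two `register` calls is dropped, both
variants return the replacement class, which is the "order of events" that
kept the `register_operation` failure hidden.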

I'll take this opportunity to propose we ditch the caches altogether.
I've been cargo-culting them but I really don't understand how they
work. There's this comment above `operationClassMapCache`

```cpp
  /// Cache of operation name to external operation class object. This is
  /// maintained on lookup as a shadow of operationClassMap in order for repeat
  /// lookups of the classes to only incur the cost of one hashtable lookup.
  llvm::StringMap<pybind11::object> operationClassMapCache;
```

But I don't understand how that's true given that the canonical thing
`operationClassMap` is already a map:

```cpp
  /// Map of full operation name to external operation class object.
  llvm::StringMap<pybind11::object> operationClassMap;
```

Maybe it wasn't always the case? Anyway things work now but it seems
like an unnecessary layer of complexity for not much gain? But maybe I'm
wrong.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/test/python/dialects/python_test.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index a1c8ab7a09ce155..f8e22f7bb0c1ba7 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -82,6 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
   if (found && !found.is_none() && !replace)
     throw std::runtime_error("Type caster is already registered");
   found = std::move(typeCaster);
+  const auto foundIt = typeCasterMapCache.find(mlirTypeID);
+  if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
+    typeCasterMapCache[mlirTypeID] = found;
+  }
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -104,6 +108,10 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
                                  .str());
   }
   found = std::move(pyClass);
+  auto foundIt = operationClassMapCache.find(operationName);
+  if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
+    operationClassMapCache[operationName] = found;
+  }
 }
 
 std::optional<py::function>

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 651e6554eebe8bd..3d4cd087fbfed8f 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -510,6 +510,19 @@ def type_caster(pytype):
         except RuntimeError as e:
             print(e)
 
+        def type_caster(pytype):
+            return RankedTensorType(pytype)
+
+        # python_test dialect registers a caster for RankedTensorType in its extension (pybind) module.
+        # So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
+        register_type_caster(c.typeid, type_caster, replace=True)
+
+        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
+        # CHECK: tensor<10x10xi5>
+        print(d.type)
+        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
+        print("ranked tensor type", repr(d.type))
+
         def type_caster(pytype):
             return test.TestIntegerRankedTensorType(pytype)
 

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 129b7fa744e4721..04239b048c1c641 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -5,6 +5,8 @@
 import itertools
 from mlir.ir import *
 from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.dialects._ods_common import _cext
 
 
 def run(f):
@@ -646,6 +648,7 @@ def testKnownOpView():
       %1 = "custom.f32"() : () -> f32
       %2 = "custom.f32"() : () -> f32
       %3 = arith.addf %1, %2 : f32
+      %4 = arith.constant 0 : i32
     """
         )
         print(module)
@@ -668,6 +671,31 @@ def testKnownOpView():
         # CHECK: OpView object
         print(repr(custom))
 
+        # constant should map to an extension OpView class in the arithmetic dialect.
+        constant = module.body.operations[3]
+        # CHECK: <mlir.dialects.arith.ConstantOp object
+        print(repr(constant))
+        # Checks that the arith extension is being registered successfully
+        # (literal_value is a property on the extension class but not on the default OpView).
+        # CHECK: literal value 0
+        print("literal value", constant.literal_value)
+
+        # Checks that "late" registration/replacement (i.e., post all module loading/initialization)
+        # is working correctly.
+        @_cext.register_operation(arith._Dialect, replace=True)
+        class ConstantOp(arith.ConstantOp):
+            def __init__(self, result, value, *, loc=None, ip=None):
+                if isinstance(value, int):
+                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+                elif isinstance(value, float):
+                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+                else:
+                    super().__init__(value, loc=loc, ip=ip)
+
+        constant = module.body.operations[3]
+        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
+        print(repr(constant))
+
 
 # CHECK-LABEL: TEST: testSingleResultProperty
 @run


        

