[Mlir-commits] [mlir] [mlir][python] fix typecaster replace (PR #70264)

Maksim Levental llvmlistbot at llvm.org
Wed Oct 25 15:17:02 PDT 2023


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/70264

(No PR description was provided.)

From 9944524c52f1896d9036f6f89efd7d9375751483 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 25 Oct 2023 17:14:08 -0500
Subject: [PATCH] [mlir][python] fix typecaster replace

---
 mlir/lib/Bindings/Python/IRModule.cpp    | 4 ++++
 mlir/test/python/dialects/python_test.py | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index a1c8ab7a09ce155..1067d6e3b94f645 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -82,6 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
   if (found && !found.is_none() && !replace)
     throw std::runtime_error("Type caster is already registered");
   found = std::move(typeCaster);
+  const auto foundIt = typeCasterMapCache.find(mlirTypeID);
+  if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
+    typeCasterMapCache[mlirTypeID] = found;
+  }
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 651e6554eebe8bd..901767d5f34a4dc 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -511,14 +511,14 @@ def type_caster(pytype):
             print(e)
 
         def type_caster(pytype):
-            return test.TestIntegerRankedTensorType(pytype)
+            return RankedTensorType(pytype)
 
         register_type_caster(c.typeid, type_caster, replace=True)
 
         d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
         # CHECK: tensor<10x10xi5>
         print(d.type)
-        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
+        # CHECK: RankedTensorType(tensor<10x10xi5>)
         print(repr(d.type))
 
 



More information about the Mlir-commits mailing list