[Mlir-commits] [mlir] [mlir][python] fix typecaster replace (PR #70264)
Maksim Levental
llvmlistbot at llvm.org
Wed Oct 25 16:04:32 PDT 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/70264
From edf0b73cd40667f882b3cd498c3d57e018c3e234 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 25 Oct 2023 17:14:08 -0500
Subject: [PATCH] [mlir][python] fix typecaster replace
---
mlir/lib/Bindings/Python/IRModule.cpp | 4 ++++
mlir/test/python/dialects/python_test.py | 6 +++---
mlir/test/python/ir/operation.py | 8 ++++++++
3 files changed, 15 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index a1c8ab7a09ce155..1067d6e3b94f645 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -82,6 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
if (found && !found.is_none() && !replace)
throw std::runtime_error("Type caster is already registered");
found = std::move(typeCaster);
+ const auto foundIt = typeCasterMapCache.find(mlirTypeID);
+ if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
+ typeCasterMapCache[mlirTypeID] = found;
+ }
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 651e6554eebe8bd..d726b2966b71c6c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -511,15 +511,15 @@ def type_caster(pytype):
print(e)
def type_caster(pytype):
- return test.TestIntegerRankedTensorType(pytype)
+ return RankedTensorType(pytype)
register_type_caster(c.typeid, type_caster, replace=True)
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
# CHECK: tensor<10x10xi5>
print(d.type)
- # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
- print(repr(d.type))
+ # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
+ print("ranked tensor type", repr(d.type))
# CHECK-LABEL: TEST: testInferTypeOpInterface
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 129b7fa744e4721..e251ec702474f1b 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -646,6 +646,7 @@ def testKnownOpView():
%1 = "custom.f32"() : () -> f32
%2 = "custom.f32"() : () -> f32
%3 = arith.addf %1, %2 : f32
+ %4 = arith.constant 0 : i32
"""
)
print(module)
@@ -668,6 +669,13 @@ def testKnownOpView():
# CHECK: OpView object
print(repr(custom))
+ # constant should map to an extension OpView class in the arithmetic dialect.
+ constant = module.body.operations[3]
+ # CHECK: <mlir.dialects.arith.ConstantOp object
+ print(repr(constant))
+ # CHECK: literal value 0
+ print("literal value", constant.literal_value)
+
# CHECK-LABEL: TEST: testSingleResultProperty
@run
More information about the Mlir-commits mailing list