[Mlir-commits] [mlir] e2d2b23 - [MLIR][Python] Add `convert_type` API for TypeConverter (#183561)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 26 08:30:07 PST 2026


Author: Twice
Date: 2026-02-27T00:30:00+08:00
New Revision: e2d2b23a5aeb6292fd135f7677a8655800ea7c20

URL: https://github.com/llvm/llvm-project/commit/e2d2b23a5aeb6292fd135f7677a8655800ea7c20
DIFF: https://github.com/llvm/llvm-project/commit/e2d2b23a5aeb6292fd135f7677a8655800ea7c20.diff

LOG: [MLIR][Python] Add `convert_type` API for TypeConverter (#183561)

This PR adds the `convert_type` API for `TypeConverter`, exposing it through the C API (`mlirTypeConverterConvertType`) and the Python bindings.

Added: 
    

Modified: 
    mlir/include/mlir-c/Rewrite.h
    mlir/lib/Bindings/Python/Rewrite.cpp
    mlir/lib/CAPI/Transforms/Rewrite.cpp
    mlir/test/python/rewrite.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index b4f93fd5a9b78..5e952edad23cb 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -546,6 +546,10 @@ mlirTypeConverterAddConversion(MlirTypeConverter typeConverter,
                                MlirTypeConverterConversionCallback convertType,
                                void *userData);
 
+/// Converts `type` with `typeConverter`; returns a null type on failure.
+MLIR_CAPI_EXPORTED MlirType
+mlirTypeConverterConvertType(MlirTypeConverter typeConverter, MlirType type);
+
 //===----------------------------------------------------------------------===//
 /// ConversionPattern API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e370552c00a9a..6c414e1a4c023 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -101,6 +101,15 @@ class PyTypeConverter {
         convert.ptr());
   }
 
+  nb::typed<nb::object, std::optional<PyType>> convertType(PyType &type) {
+    MlirType converted = mlirTypeConverterConvertType(typeConverter, type);
+    if (mlirTypeIsNull(converted)) // Conversion failed.
+      return nb::none();
+    return PyType(PyMlirContext::forContext(mlirTypeGetContext(converted)),
+                  converted)
+        .maybeDownCast(); // Downcast to the concrete subclass (e.g. IntType).
+  }
+
   MlirTypeConverter get() { return typeConverter; }
 
 private:
@@ -621,7 +630,9 @@ void populateRewriteSubmodule(nb::module_ &m) {
   nb::class_<PyTypeConverter>(m, "TypeConverter")
       .def(nb::init<>(), "Create a new TypeConverter.")
       .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a,
-           nb::keep_alive<0, 1>(), "Register a type conversion function.");
+           nb::keep_alive<0, 1>(), "Register a type conversion function.")
+      .def("convert_type", &PyTypeConverter::convertType, "type"_a,
+           "Convert the given type. Returns None if conversion fails.");
 
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule

diff  --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 5900f08ae1730..a7e43254767ad 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -590,6 +590,11 @@ void mlirTypeConverterAddConversion(
           });
 }
 
+MlirType mlirTypeConverterConvertType(MlirTypeConverter typeConverter,
+                                      MlirType type) { // Null on failure.
+  return wrap(unwrap(typeConverter)->convertType(unwrap(type)));
+}
+
 //===----------------------------------------------------------------------===//
 /// ConversionPattern API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 1a9bcc87a3bda..35d88833e69e0 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -318,3 +318,12 @@ def convert_muli(op, adaptor, type_converter, rewriter):
             # CHECK: caught exception: partial conversion failed
             # CHECK: failed to legalize unresolved materialization
             print("caught exception:", e)
+
+        t1 = converter.convert_type(IntegerType.get_signless(64))
+        # CHECK: IntType
+        print(type(t1))
+        # CHECK: !smt.int
+        print(str(t1))
+        t2 = converter.convert_type(F32Type.get())
+        # CHECK: None
+        print(t2)


        


More information about the Mlir-commits mailing list