[Mlir-commits] [mlir] 3f66484 - [MLIR][Python] Fix typeid support for DynamicType and DynamicAttr (#183076)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 25 05:58:18 PST 2026


Author: Twice
Date: 2026-02-25T21:58:13+08:00
New Revision: 3f6648422c65da120ebd12474cfc2cf7d3f6b655

URL: https://github.com/llvm/llvm-project/commit/3f6648422c65da120ebd12474cfc2cf7d3f6b655
DIFF: https://github.com/llvm/llvm-project/commit/3f6648422c65da120ebd12474cfc2cf7d3f6b655.diff

LOG: [MLIR][Python] Fix typeid support for DynamicType and DynamicAttr (#183076)

Previously, we were using the static `typeid` of `DynamicType` for
checks, which is incorrect. We should instead check against the `typeid`
of `DynamicTypeDefinition` (which is a subclass of `SelfOwningTypeID`),
and register it via `register_type_caster` so that Python-defined types
can use `maybe_downcast`. (The same change applies to attributes.)

Added: 
    

Modified: 
    mlir/include/mlir-c/ExtensibleDialect.h
    mlir/include/mlir/Bindings/Python/IRAttributes.h
    mlir/include/mlir/Bindings/Python/IRTypes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/ExtensibleDialect.cpp
    mlir/python/mlir/dialects/ext.py
    mlir/test/python/dialects/ext.py
    mlir/test/python/dialects/irdl.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index d6aa8181c024c..fee26772e1560 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -87,9 +87,6 @@ mlirExtensibleDialectLookupTypeDefinition(MlirDialect dialect,
 /// Check if the given type is a dynamic type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsADynamicType(MlirType type);
 
-/// Get the type ID of a dynamic type.
-MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicTypeGetTypeID(void);
-
 /// Get a dynamic type by instantiating the given type definition with the
 /// provided attributes.
 MLIR_CAPI_EXPORTED MlirType mlirDynamicTypeGet(
@@ -106,6 +103,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDynamicTypeGetParam(MlirType type,
 MLIR_CAPI_EXPORTED MlirDynamicTypeDefinition
 mlirDynamicTypeGetTypeDef(MlirType type);
 
+/// Get the type ID of a dynamic type definition.
+MLIR_CAPI_EXPORTED MlirTypeID
+mlirDynamicTypeDefinitionGetTypeID(MlirDynamicTypeDefinition typeDef);
+
 /// Get the name of the given dynamic type definition.
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirDynamicTypeDefinitionGetName(MlirDynamicTypeDefinition typeDef);
@@ -123,9 +124,6 @@ mlirExtensibleDialectLookupAttrDefinition(MlirDialect dialect,
 /// Check if the given attribute is a dynamic attribute.
 MLIR_CAPI_EXPORTED bool mlirAttributeIsADynamicAttr(MlirAttribute attr);
 
-/// Get the type ID of a dynamic attribute.
-MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicAttrGetTypeID(void);
-
 /// Get a dynamic attribute by instantiating the given attribute definition with
 /// the provided attributes.
 MLIR_CAPI_EXPORTED MlirAttribute mlirDynamicAttrGet(
@@ -142,6 +140,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDynamicAttrGetParam(MlirAttribute attr,
 MLIR_CAPI_EXPORTED MlirDynamicAttrDefinition
 mlirDynamicAttrGetAttrDef(MlirAttribute attr);
 
+/// Get the type ID of a dynamic attribute definition.
+MLIR_CAPI_EXPORTED MlirTypeID
+mlirDynamicAttrDefinitionGetTypeID(MlirDynamicAttrDefinition attrDef);
+
 /// Get the name of the given dynamic attribute definition.
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirDynamicAttrDefinitionGetName(MlirDynamicAttrDefinition attrDef);

diff  --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 2f1e4a2ad99d0..173674e0091d2 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -594,8 +594,6 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicAttribute
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADynamicAttr;
   static constexpr const char *pyClassName = "DynamicAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirDynamicAttrGetTypeID;
 
   static void bindDerived(ClassTy &c);
 };

diff  --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index b29483ef78e5d..8e81ee9805200 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -450,8 +450,6 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicType
     : public PyConcreteType<PyDynamicType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsADynamicType;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirDynamicTypeGetTypeID;
   static constexpr const char *pyClassName = "DynamicType";
   using PyConcreteType::PyConcreteType;
 

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 50eb336c9201c..0536cffcf100c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -18,6 +18,7 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/ExtensibleDialect.h"
 #include "mlir/Bindings/Python/IRAttributes.h"
 #include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
@@ -1376,41 +1377,47 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
       "Returns the value of the string attribute as `bytes`");
 }
 
+static MlirDynamicAttrDefinition
+getDynamicAttrDef(const std::string &fullAttrName,
+                  DefaultingPyMlirContext context) {
+  size_t dotPos = fullAttrName.find('.');
+  if (dotPos == std::string::npos) {
+    throw nb::value_error("Expected full attribute name to be in the format "
+                          "'<dialectName>.<attributeName>'.");
+  }
+
+  std::string dialectName = fullAttrName.substr(0, dotPos);
+  std::string attrName = fullAttrName.substr(dotPos + 1);
+  PyDialects dialects(context->getRef());
+  MlirDialect dialect = dialects.getDialectForKey(dialectName, false);
+  if (!mlirDialectIsAExtensibleDialect(dialect))
+    throw nb::value_error(
+        ("Dialect '" + dialectName + "' is not an extensible dialect.")
+            .c_str());
+
+  MlirDynamicAttrDefinition attrDef = mlirExtensibleDialectLookupAttrDefinition(
+      dialect, toMlirStringRef(attrName));
+  if (attrDef.ptr == nullptr) {
+    throw nb::value_error(("Dialect '" + dialectName +
+                           "' does not contain an attribute named '" +
+                           attrName + "'.")
+                              .c_str());
+  }
+  return attrDef;
+}
+
 void PyDynamicAttribute::bindDerived(ClassTy &c) {
   c.def_static(
       "get",
       [](const std::string &fullAttrName, const std::vector<PyAttribute> &attrs,
          DefaultingPyMlirContext context) {
-        size_t dotPos = fullAttrName.find('.');
-        if (dotPos == std::string::npos) {
-          throw nb::value_error(
-              "Expected full attribute name to be in the format "
-              "'<dialectName>.<attributeName>'.");
-        }
-
-        std::string dialectName = fullAttrName.substr(0, dotPos);
-        std::string attrName = fullAttrName.substr(dotPos + 1);
-        PyDialects dialects(context->getRef());
-        MlirDialect dialect = dialects.getDialectForKey(dialectName, false);
-        if (!mlirDialectIsAExtensibleDialect(dialect))
-          throw nb::value_error(
-              ("Dialect '" + dialectName + "' is not an extensible dialect.")
-                  .c_str());
-
-        MlirDynamicAttrDefinition attrDef =
-            mlirExtensibleDialectLookupAttrDefinition(
-                dialect, toMlirStringRef(attrName));
-        if (attrDef.ptr == nullptr) {
-          throw nb::value_error(("Dialect '" + dialectName +
-                                 "' does not contain an attribute named '" +
-                                 attrName + "'.")
-                                    .c_str());
-        }
-
         std::vector<MlirAttribute> mlirAttrs;
         mlirAttrs.reserve(attrs.size());
         for (const auto &attr : attrs)
           mlirAttrs.push_back(attr.get());
+
+        MlirDynamicAttrDefinition attrDef =
+            getDynamicAttrDef(fullAttrName, context);
         MlirAttribute attr =
             mlirDynamicAttrGet(attrDef, mlirAttrs.data(), mlirAttrs.size());
         return PyDynamicAttribute(context->getRef(), attr);
@@ -1438,6 +1445,15 @@ void PyDynamicAttribute::bindDerived(ClassTy &c) {
     return std::string(dialectNamespace.data, dialectNamespace.length) + "." +
            std::string(name.data, name.length);
   });
+  c.def_static(
+      "lookup_typeid",
+      [](const std::string &fullAttrName, DefaultingPyMlirContext context) {
+        MlirDynamicAttrDefinition attrDef =
+            getDynamicAttrDef(fullAttrName, context);
+        return PyTypeID(mlirDynamicAttrDefinitionGetTypeID(attrDef));
+      },
+      nb::arg("full_attr_name"), nb::arg("context") = nb::none(),
+      "Look up the TypeID for the given dynamic attribute name.");
 }
 
 void populateIRAttributes(nb::module_ &m) {

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index aa448fa44d9bf..e04ff99b6c5fc 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -842,35 +842,43 @@ void PyOpaqueType::bindDerived(ClassTy &c) {
       "Returns the data for the Opaque type as a string.");
 }
 
+static MlirDynamicTypeDefinition
+getDynamicTypeDef(const std::string &fullTypeName,
+                  DefaultingPyMlirContext context) {
+  size_t dotPos = fullTypeName.find('.');
+  if (dotPos == std::string::npos) {
+    throw nb::value_error("Expected full type name to be in the format "
+                          "'<dialectName>.<typeName>'.");
+  }
+
+  std::string dialectName = fullTypeName.substr(0, dotPos);
+  std::string typeName = fullTypeName.substr(dotPos + 1);
+  PyDialects dialects(context->getRef());
+  MlirDialect dialect = dialects.getDialectForKey(dialectName, false);
+  if (!mlirDialectIsAExtensibleDialect(dialect))
+    throw nb::value_error(
+        ("Dialect '" + dialectName + "' is not an extensible dialect.")
+            .c_str());
+
+  MlirDynamicTypeDefinition typeDef = mlirExtensibleDialectLookupTypeDefinition(
+      dialect, toMlirStringRef(typeName));
+  if (typeDef.ptr == nullptr) {
+    throw nb::value_error(("Dialect '" + dialectName +
+                           "' does not contain a type named '" + typeName +
+                           "'.")
+                              .c_str());
+  }
+
+  return typeDef;
+}
+
 void PyDynamicType::bindDerived(ClassTy &c) {
   c.def_static(
       "get",
       [](const std::string &fullTypeName, const std::vector<PyAttribute> &attrs,
          DefaultingPyMlirContext context) {
-        size_t dotPos = fullTypeName.find('.');
-        if (dotPos == std::string::npos) {
-          throw nb::value_error("Expected full type name to be in the format "
-                                "'<dialectName>.<typeName>'.");
-        }
-
-        std::string dialectName = fullTypeName.substr(0, dotPos);
-        std::string typeName = fullTypeName.substr(dotPos + 1);
-        PyDialects dialects(context->getRef());
-        MlirDialect dialect = dialects.getDialectForKey(dialectName, false);
-        if (!mlirDialectIsAExtensibleDialect(dialect))
-          throw nb::value_error(
-              ("Dialect '" + dialectName + "' is not an extensible dialect.")
-                  .c_str());
-
         MlirDynamicTypeDefinition typeDef =
-            mlirExtensibleDialectLookupTypeDefinition(
-                dialect, toMlirStringRef(typeName));
-        if (typeDef.ptr == nullptr) {
-          throw nb::value_error(("Dialect '" + dialectName +
-                                 "' does not contain a type named '" +
-                                 typeName + "'.")
-                                    .c_str());
-        }
+            getDynamicTypeDef(fullTypeName, context);
 
         std::vector<MlirAttribute> mlirAttrs;
         mlirAttrs.reserve(attrs.size());
@@ -902,6 +910,15 @@ void PyDynamicType::bindDerived(ClassTy &c) {
     return std::string(dialectNamespace.data, dialectNamespace.length) + "." +
            std::string(name.data, name.length);
   });
+  c.def_static(
+      "lookup_typeid",
+      [](const std::string &fullTypeName, DefaultingPyMlirContext context) {
+        MlirDynamicTypeDefinition typeDef =
+            getDynamicTypeDef(fullTypeName, context);
+        return PyTypeID(mlirDynamicTypeDefinitionGetTypeID(typeDef));
+      },
+      nb::arg("full_type_name"), nb::arg("context") = nb::none(),
+      "Look up the TypeID for the given dynamic type name.");
 }
 
 void populateIRTypes(nb::module_ &m) {

diff  --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 51cec5f95a201..60a7f7d9064eb 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -129,6 +129,11 @@ MlirDynamicTypeDefinition mlirDynamicTypeGetTypeDef(MlirType type) {
   return wrap(llvm::cast<mlir::DynamicType>(unwrap(type)).getTypeDef());
 }
 
+MlirTypeID
+mlirDynamicTypeDefinitionGetTypeID(MlirDynamicTypeDefinition typeDef) {
+  return wrap(unwrap(typeDef)->getTypeID());
+}
+
 MlirStringRef
 mlirDynamicTypeDefinitionGetName(MlirDynamicTypeDefinition typeDef) {
   return wrap(unwrap(typeDef)->getName());
@@ -176,6 +181,11 @@ MlirDynamicAttrDefinition mlirDynamicAttrGetAttrDef(MlirAttribute attr) {
   return wrap(llvm::cast<mlir::DynamicAttr>(unwrap(attr)).getAttrDef());
 }
 
+MlirTypeID
+mlirDynamicAttrDefinitionGetTypeID(MlirDynamicAttrDefinition attrDef) {
+  return wrap(unwrap(attrDef)->getTypeID());
+}
+
 MlirStringRef
 mlirDynamicAttrDefinitionGetName(MlirDynamicAttrDefinition attrDef) {
   return wrap(unwrap(attrDef)->getName());

diff  --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 077480ea437ea..39aacf32dabb9 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -678,3 +678,7 @@ def load(cls, register=True, reload=False) -> None:
             register_dialect_operation = register_operation(cls)
             for op in cls.operations:
                 register_dialect_operation(op)
+
+            for type_ in cls.types:
+                typeid = ir.DynamicType.lookup_typeid(type_.type_name)
+                _cext.register_type_caster(typeid)(type_)

diff  --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index fdd2a8a7cdb5b..f9252bad37a39 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -581,6 +581,9 @@ class MakeArray3Op(TestType.Operation, name="make_array3"):
         # CHECK: 6 : i32
         print(a6.length)
 
+        # CHECK: <locals>.Array
+        print(type(Type(a4).maybe_downcast()))
+
         module = Module.create()
         with InsertionPoint(module.body):
             MakeArrayOp(a4)

diff  --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index 004a795511f75..fce5ad212e374 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -134,6 +134,22 @@ def testIRDLTypes():
         # CHECK: unit
         t2.params[1].dump()
 
+        # CHECK: True
+        print(
+            t2.typeid == DynamicType.lookup_typeid("irdl_type_test.type2"),
+            file=sys.stderr,
+        )
+        # CHECK: False
+        print(
+            t1.typeid == DynamicType.lookup_typeid("irdl_type_test.type2"),
+            file=sys.stderr,
+        )
+        # CHECK: True
+        print(
+            t1.typeid == DynamicType.lookup_typeid("irdl_type_test.type1"),
+            file=sys.stderr,
+        )
+
         m = Module.create()
         with InsertionPoint(m.body):
             Operation.create("irdl_type_test.op1", results=[t1])
@@ -211,6 +227,22 @@ def testIRDLAttrs():
         # CHECK: unit
         a2.params[1].dump()
 
+        # CHECK: True
+        print(
+            a2.typeid == DynamicAttr.lookup_typeid("irdl_attr_test.attr2"),
+            file=sys.stderr,
+        )
+        # CHECK: False
+        print(
+            a1.typeid == DynamicAttr.lookup_typeid("irdl_attr_test.attr2"),
+            file=sys.stderr,
+        )
+        # CHECK: True
+        print(
+            a1.typeid == DynamicAttr.lookup_typeid("irdl_attr_test.attr1"),
+            file=sys.stderr,
+        )
+
         m = Module.create()
         with InsertionPoint(m.body):
             Operation.create("irdl_attr_test.op1", attributes={"attr": a1})


        


More information about the Mlir-commits mailing list