[Mlir-commits] [mlir] [mlir][Python] use canonical Python `isinstance` instead of `Type.isinstance` (PR #172892)

Maksim Levental llvmlistbot at llvm.org
Mon Jan 5 12:58:02 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/172892

>From da8faccfdecbc7b4e08014df8131c70f50f3cec4 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 18 Dec 2025 11:22:57 -0800
Subject: [PATCH 1/4] [mlir][Python] use canonical Python isinstance instead of
 Type.isinstance

---
 mlir/python/mlir/dialects/arith.py            |  24 +---
 .../dialects/linalg/opdsl/lang/emitter.py     | 107 +++++-------------
 mlir/python/mlir/dialects/memref.py           |  15 ++-
 mlir/test/python/dialects/arith_dialect.py    |   6 +-
 mlir/test/python/ir/auto_location.py          |   2 +-
 5 files changed, 48 insertions(+), 106 deletions(-)

diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 88e8502a29eae..555fb4c5ef3eb 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -21,26 +21,6 @@
     raise RuntimeError("Error loading imports from extension module") from e
 
 
-def _isa(obj: Any, cls: type):
-    try:
-        cls(obj)
-    except ValueError:
-        return False
-    return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
-    return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
-    return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
-    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
@@ -96,9 +76,9 @@ def value(self):
 
     @property
     def literal_value(self) -> Union[int, float]:
-        if _is_integer_like_type(self.type):
+        if isinstance(self.type, (IntegerType, IndexType)):
             return IntegerAttr(self.value).value
-        elif _is_float_type(self.type):
+        elif isinstance(self.type, FloatType):
             return FloatAttr(self.value).value
         else:
             raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index fb2570c7bb498..126ebfca9e0bb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -412,9 +412,9 @@ def _cast(
             )
         if operand.type == to_type:
             return operand
-        if _is_integer_type(to_type):
+        if isinstance(to_type, IntegerType):
             return self._cast_to_integer(to_type, operand, is_unsigned_cast)
-        elif _is_floating_point_type(to_type):
+        elif isinstance(to_type, FloatType):
             return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
 
     def _cast_to_integer(
@@ -422,11 +422,11 @@ def _cast_to_integer(
     ) -> Value:
         to_width = IntegerType(to_type).width
         operand_type = operand.type
-        if _is_floating_point_type(operand_type):
+        if isinstance(operand_type, FloatType):
             if is_unsigned_cast:
                 return arith.FPToUIOp(to_type, operand).result
             return arith.FPToSIOp(to_type, operand).result
-        if _is_index_type(operand_type):
+        if isinstance(operand_type, IndexType):
             return arith.IndexCastOp(to_type, operand).result
         # Assume integer.
         from_width = IntegerType(operand_type).width
@@ -444,13 +444,15 @@ def _cast_to_floating_point(
         self, to_type: Type, operand: Value, is_unsigned_cast: bool
     ) -> Value:
         operand_type = operand.type
-        if _is_integer_type(operand_type):
+        if isinstance(operand_type, IntegerType):
             if is_unsigned_cast:
                 return arith.UIToFPOp(to_type, operand).result
             return arith.SIToFPOp(to_type, operand).result
         # Assume FloatType.
-        to_width = _get_floating_point_width(to_type)
-        from_width = _get_floating_point_width(operand_type)
+        assert isinstance(to_type, FloatType)
+        assert isinstance(operand_type, FloatType)
+        to_width = to_type.width
+        from_width = operand_type.width
         if to_width > from_width:
             return arith.ExtFOp(to_type, operand).result
         elif to_width < from_width:
@@ -466,89 +468,85 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
         return self._cast(type_var_name, operand, True)
 
     def _unary_exp(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if isinstance(x.type, FloatType):
             return math.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if isinstance(x.type, FloatType):
             return math.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if isinstance(x.type, FloatType):
             return math.AbsFOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if isinstance(x.type, FloatType):
             return math.CeilOp(x).result
         raise NotImplementedError("Unsupported 'ceil' operand: {x}")
 
     def _unary_floor(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if isinstance(x.type, FloatType):
             return math.FloorOp(x).result
         raise NotImplementedError("Unsupported 'floor' operand: {x}")
 
     def _unary_negf(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if isinstance(x.type, FloatType):
             return arith.NegFOp(x).result
-        if _is_complex_type(x.type):
+        if isinstance(x.type, ComplexType):
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if isinstance(lhs.type, FloatType):
             return arith.AddFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
             return arith.AddIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if isinstance(lhs.type, ComplexType):
             return complex.AddOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
 
     def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if isinstance(lhs.type, FloatType):
             return arith.SubFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
             return arith.SubIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if isinstance(lhs.type, ComplexType):
             return complex.SubOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
 
     def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if isinstance(lhs.type, FloatType):
             return arith.MulFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
             return arith.MulIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if isinstance(lhs.type, ComplexType):
             return complex.MulOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
 
     def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if isinstance(lhs.type, FloatType):
             return arith.MaximumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
             return arith.MaxSIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
 
     def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if (
-            _is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
-        ) or _is_index_type(lhs.type):
+        if (isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)) or isinstance(lhs.type, IndexType):
             return arith.MaxUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
 
     def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if isinstance(lhs.type, FloatType):
             return arith.MinimumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
             return arith.MinSIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
 
     def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if (
-            _is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
-        ) or _is_index_type(lhs.type):
+        if (isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)) or isinstance(lhs.type, IndexType):
             return arith.MinUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
 
@@ -609,46 +607,3 @@ def _add_type_mapping(
             )
     type_mapping[name] = element_or_self_type
     block_arg_types.append(element_or_self_type)
-
-
-def _is_complex_type(t: Type) -> bool:
-    return ComplexType.isinstance(t)
-
-
-def _is_floating_point_type(t: Type) -> bool:
-    # TODO: Create a FloatType in the Python API and implement the switch
-    # there.
-    return (
-        F64Type.isinstance(t)
-        or F32Type.isinstance(t)
-        or F16Type.isinstance(t)
-        or BF16Type.isinstance(t)
-    )
-
-
-def _is_integer_type(t: Type) -> bool:
-    return IntegerType.isinstance(t)
-
-
-def _is_index_type(t: Type) -> bool:
-    return IndexType.isinstance(t)
-
-
-def _is_bool_type(t: Type) -> bool:
-    if not IntegerType.isinstance(t):
-        return False
-    return IntegerType(t).width == 1
-
-
-def _get_floating_point_width(t: Type) -> int:
-    # TODO: Create a FloatType in the Python API and implement the switch
-    # there.
-    if F64Type.isinstance(t):
-        return 64
-    if F32Type.isinstance(t):
-        return 32
-    if F16Type.isinstance(t):
-        return 16
-    if BF16Type.isinstance(t):
-        return 16
-    raise NotImplementedError(f"Unhandled floating point type switch {t}")
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 910a2356ca0e4..34f00a3292b79 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -8,15 +8,22 @@
 from ._memref_ops_gen import *
 from ._memref_ops_gen import _Dialect
 from ._ods_common import _dispatch_mixed_values, MixedValues
-from .arith import ConstantOp, _is_integer_like_type
-from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
+from ..ir import (
+    IndexType,
+    IntegerType,
+    MemRefType,
+    ShapedType,
+    StridedLayoutAttr,
+    Value,
+)
+from . import arith
 
 
 def _is_constant_int_like(i):
     return (
         isinstance(i, Value)
-        and isinstance(i.owner, ConstantOp)
-        and _is_integer_like_type(i.type)
+        and isinstance(i.owner, arith.ConstantOp)
+        and isinstance(i.type, (IntegerType, IndexType))
     )
 
 
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db8..a4cfb30240231 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -42,10 +42,10 @@ def testFastMathFlags():
 def testArithValue():
     def _binary_op(lhs, rhs, op: str) -> "ArithValue":
         op = op.capitalize()
-        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+        if isinstance(lhs.type, FloatType) and isinstance(rhs.type, FloatType):
             op += "F"
-        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
-            lhs.type
+        elif isinstance(lhs.type, (IntegerType, IndexType)) and isinstance(
+            lhs.type, (IntegerType, IndexType)
         ):
             op += "I"
         else:
diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py
index 1747c66aa6639..6448a88dc1775 100644
--- a/mlir/test/python/ir/auto_location.py
+++ b/mlir/test/python/ir/auto_location.py
@@ -34,7 +34,7 @@ def testInferLocations():
         # Test nesting of loc_tracebacks().
         with loc_tracebacks():
             # fmt: off
-            # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
+            # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":45:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":90:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
             # fmt: on
             print(one.location)
 

>From dce82d9bd3236ddecee55f5f076e3e6c4609712b Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 11:23:29 -0800
Subject: [PATCH 2/4] remove "isinstance" from core bindings

---
 mlir/include/mlir/Bindings/Python/IRCore.h    | 12 ------------
 mlir/lib/Bindings/Python/IRAffine.cpp         |  6 ------
 mlir/test/mlir-tblgen/op-python-bindings.td   |  2 +-
 mlir/test/python/ir/affine_expr.py            | 16 ++++++++--------
 mlir/test/python/ir/attributes.py             |  8 ++++----
 mlir/test/python/ir/builtin_types.py          | 10 +++++-----
 mlir/test/python/ir/value.py                  |  8 ++++----
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |  2 +-
 8 files changed, 23 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 4ff5061b945aa..0df8a4840bde3 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -957,12 +957,6 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
     auto cls = ClassTy(m, DerivedTy::pyClassName);
     cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
             nanobind::arg("cast_from_type"));
-    cls.def_static(
-        "isinstance",
-        [](PyType &otherType) -> bool {
-          return DerivedTy::isaFunction(otherType);
-        },
-        nanobind::arg("other"));
     cls.def_prop_ro_static(
         "static_typeid",
         [](nanobind::object & /*class*/) {
@@ -1094,12 +1088,6 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
     }
     cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
             nanobind::arg("cast_from_attr"));
-    cls.def_static(
-        "isinstance",
-        [](PyAttribute &otherAttr) -> bool {
-          return DerivedTy::isaFunction(otherAttr);
-        },
-        nanobind::arg("other"));
     cls.def_prop_ro(
         "type",
         [](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index ce235470bbdc7..b3d15ee59566b 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -118,12 +118,6 @@ class PyConcreteAffineExpr : public BaseTy {
   static void bind(nb::module_ &m) {
     auto cls = ClassTy(m, DerivedTy::pyClassName);
     cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
-    cls.def_static(
-        "isinstance",
-        [](PyAffineExpr &otherAffineExpr) -> bool {
-          return DerivedTy::isaFunction(otherAffineExpr);
-        },
-        nb::arg("other"));
     DerivedTy::bindDerived(cls);
   }
 
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index ff16ad8ca0cdd..929851724ba71 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -233,7 +233,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
   // CHECK:     _ods_result_type_source_attr = attributes["type"]
   // CHECK:     _ods_derived_result_type = (
   // CHECK:       _ods_ir.TypeAttr(_ods_result_type_source_attr).value
-  // CHECK:       if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
+  // CHECK:       if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else
   // CHECK:       _ods_result_type_source_attr.type)
   // CHECK:     results = [_ods_derived_result_type] * 2
   let arguments = (ins TypeAttr:$type);
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index c2a2ab3509ca6..82c509efdf8fc 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -354,21 +354,21 @@ def testIsInstance():
         mul = AffineMulExpr.get(d1, c2)
 
         # CHECK: True
-        print(AffineDimExpr.isinstance(d1))
+        print(isinstance(d1, AffineDimExpr))
         # CHECK: False
-        print(AffineConstantExpr.isinstance(d1))
+        print(isinstance(d1, AffineConstantExpr))
         # CHECK: True
-        print(AffineConstantExpr.isinstance(c2))
+        print(isinstance(c2, AffineConstantExpr))
         # CHECK: False
-        print(AffineMulExpr.isinstance(c2))
+        print(isinstance(c2, AffineMulExpr))
         # CHECK: True
-        print(AffineAddExpr.isinstance(add))
+        print(isinstance(add, AffineAddExpr))
         # CHECK: False
-        print(AffineMulExpr.isinstance(add))
+        print(isinstance(add, AffineMulExpr))
         # CHECK: True
-        print(AffineMulExpr.isinstance(mul))
+        print(isinstance(mul, AffineMulExpr))
         # CHECK: False
-        print(AffineAddExpr.isinstance(mul))
+        print(isinstance(mul, AffineAddExpr))
 
 
 # CHECK-LABEL: TEST: testCompose
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 2f3c4460d3f59..5ab671bd4d298 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -94,10 +94,10 @@ def testAttrIsInstance():
     with Context():
         a1 = Attribute.parse("42")
         a2 = Attribute.parse("[42]")
-        assert IntegerAttr.isinstance(a1)
-        assert not IntegerAttr.isinstance(a2)
-        assert not ArrayAttr.isinstance(a1)
-        assert ArrayAttr.isinstance(a2)
+        assert isinstance(a1, IntegerAttr)
+        assert not isinstance(a2, IntegerAttr)
+        assert not isinstance(a1, ArrayAttr)
+        assert isinstance(a2, ArrayAttr)
 
 
 # CHECK-LABEL: TEST: testAttrEqDoesNotRaise
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 54863253fc770..aa1665a4020fc 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -97,15 +97,15 @@ def testTypeIsInstance():
     t1 = Type.parse("i32", ctx)
     t2 = Type.parse("f32", ctx)
     # CHECK: True
-    print(IntegerType.isinstance(t1))
+    print(isinstance(t1, IntegerType))
     # CHECK: False
-    print(F32Type.isinstance(t1))
+    print(isinstance(t1, F32Type))
     # CHECK: False
-    print(FloatType.isinstance(t1))
+    print(isinstance(t1, FloatType))
     # CHECK: True
-    print(F32Type.isinstance(t2))
+    print(isinstance(t2, F32Type))
     # CHECK: True
-    print(FloatType.isinstance(t2))
+    print(isinstance(t2, FloatType))
 
 
 # CHECK-LABEL: TEST: testFloatTypeSubclasses
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 4a241afb8e89d..45efb880bab44 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -69,12 +69,12 @@ def testValueIsInstance():
         ctx,
     )
     func = module.body.operations[0]
-    assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
-    assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
+    assert isinstance(func.regions[0].blocks[0].arguments[0], BlockArgument)
+    assert not isinstance(func.regions[0].blocks[0].arguments[0], OpResult)
 
     op = func.regions[0].blocks[0].operations[0]
-    assert not BlockArgument.isinstance(op.results[0])
-    assert OpResult.isinstance(op.results[0])
+    assert not isinstance(op.results[0], BlockArgument)
+    assert isinstance(op.results[0], OpResult)
 
 
 # CHECK-LABEL: TEST: testValueHash
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2c33f4efac3ac..6545559ff1b10 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -886,7 +886,7 @@ constexpr const char *firstAttrDerivedResultTypeTemplate =
   _ods_result_type_source_attr = attributes["{0}"]
   _ods_derived_result_type = (
     _ods_ir.TypeAttr(_ods_result_type_source_attr).value
-    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
+    if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else
     _ods_result_type_source_attr.type)
   results = [_ods_derived_result_type] * {1})Py";
 

>From debd17b98c8fcc80eb12c3f3bd0b0f4322147881 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 5 Jan 2026 12:09:56 -0800
Subject: [PATCH 3/4] fix after rebase (quant and pdl broken)

---
 mlir/include/mlir/Bindings/Python/IRCore.h    |   6 -
 .../dialects/linalg/opdsl/lang/emitter.py     |  14 ++-
 mlir/test/python/dialects/pdl_types.py        | 104 +++++++++---------
 mlir/test/python/dialects/quant.py            |  34 +++---
 4 files changed, 81 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 0df8a4840bde3..4930ce5ca6b8d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1543,12 +1543,6 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
                           .c_str()));
     cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),
             nanobind::arg("value"));
-    cls.def_static(
-        "isinstance",
-        [](PyValue &otherValue) -> bool {
-          return DerivedTy::isaFunction(otherValue);
-        },
-        nanobind::arg("other_value"));
     cls.def(
         MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
         [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 126ebfca9e0bb..af90f3f8c4e3c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -534,7 +534,9 @@ def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
         raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
 
     def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if (isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)) or isinstance(lhs.type, IndexType):
+        if (
+            isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)
+        ) or isinstance(lhs.type, IndexType):
             return arith.MaxUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
 
@@ -546,7 +548,9 @@ def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
         raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
 
     def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if (isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)) or isinstance(lhs.type, IndexType):
+        if (
+            isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)
+        ) or isinstance(lhs.type, IndexType):
             return arith.MinUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
 
@@ -607,3 +611,9 @@ def _add_type_mapping(
             )
     type_mapping[name] = element_or_self_type
     block_arg_types.append(element_or_self_type)
+
+
+def _is_bool_type(t: Type) -> bool:
+    if not isinstance(t, IntegerType):
+        return False
+    return t.width == 1
diff --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py
index f75428d295c9c..58c9e74e95bff 100644
--- a/mlir/test/python/dialects/pdl_types.py
+++ b/mlir/test/python/dialects/pdl_types.py
@@ -17,17 +17,17 @@ def test_attribute_type():
         parsedType = Type.parse("!pdl.attribute")
         constructedType = pdl.AttributeType.get()
 
-        assert pdl.AttributeType.isinstance(parsedType)
-        assert not pdl.OperationType.isinstance(parsedType)
-        assert not pdl.RangeType.isinstance(parsedType)
-        assert not pdl.TypeType.isinstance(parsedType)
-        assert not pdl.ValueType.isinstance(parsedType)
-
-        assert pdl.AttributeType.isinstance(constructedType)
-        assert not pdl.OperationType.isinstance(constructedType)
-        assert not pdl.RangeType.isinstance(constructedType)
-        assert not pdl.TypeType.isinstance(constructedType)
-        assert not pdl.ValueType.isinstance(constructedType)
+        assert isinstance(parsedType, pdl.AttributeType)
+        assert not isinstance(parsedType, pdl.OperationType)
+        assert not isinstance(parsedType, pdl.RangeType)
+        assert not isinstance(parsedType, pdl.TypeType)
+        assert not isinstance(parsedType, pdl.ValueType)
+
+        assert isinstance(constructedType, pdl.AttributeType)
+        assert not isinstance(constructedType, pdl.OperationType)
+        assert not isinstance(constructedType, pdl.RangeType)
+        assert not isinstance(constructedType, pdl.TypeType)
+        assert not isinstance(constructedType, pdl.ValueType)
 
         assert parsedType == constructedType
 
@@ -44,17 +44,17 @@ def test_operation_type():
         parsedType = Type.parse("!pdl.operation")
         constructedType = pdl.OperationType.get()
 
-        assert not pdl.AttributeType.isinstance(parsedType)
-        assert pdl.OperationType.isinstance(parsedType)
-        assert not pdl.RangeType.isinstance(parsedType)
-        assert not pdl.TypeType.isinstance(parsedType)
-        assert not pdl.ValueType.isinstance(parsedType)
+        assert not isinstance(parsedType, pdl.AttributeType)
+        assert isinstance(parsedType, pdl.OperationType)
+        assert not isinstance(parsedType, pdl.RangeType)
+        assert not isinstance(parsedType, pdl.TypeType)
+        assert not isinstance(parsedType, pdl.ValueType)
 
-        assert not pdl.AttributeType.isinstance(constructedType)
-        assert pdl.OperationType.isinstance(constructedType)
-        assert not pdl.RangeType.isinstance(constructedType)
-        assert not pdl.TypeType.isinstance(constructedType)
-        assert not pdl.ValueType.isinstance(constructedType)
+        assert not isinstance(constructedType, pdl.AttributeType)
+        assert isinstance(constructedType, pdl.OperationType)
+        assert not isinstance(constructedType, pdl.RangeType)
+        assert not isinstance(constructedType, pdl.TypeType)
+        assert not isinstance(constructedType, pdl.ValueType)
 
         assert parsedType == constructedType
 
@@ -73,17 +73,17 @@ def test_range_type():
         constructedType = pdl.RangeType.get(typeType)
         elementType = constructedType.element_type
 
-        assert not pdl.AttributeType.isinstance(parsedType)
-        assert not pdl.OperationType.isinstance(parsedType)
-        assert pdl.RangeType.isinstance(parsedType)
-        assert not pdl.TypeType.isinstance(parsedType)
-        assert not pdl.ValueType.isinstance(parsedType)
+        assert not isinstance(parsedType, pdl.AttributeType)
+        assert not isinstance(parsedType, pdl.OperationType)
+        assert isinstance(parsedType, pdl.RangeType)
+        assert not isinstance(parsedType, pdl.TypeType)
+        assert not isinstance(parsedType, pdl.ValueType)
 
-        assert not pdl.AttributeType.isinstance(constructedType)
-        assert not pdl.OperationType.isinstance(constructedType)
-        assert pdl.RangeType.isinstance(constructedType)
-        assert not pdl.TypeType.isinstance(constructedType)
-        assert not pdl.ValueType.isinstance(constructedType)
+        assert not isinstance(constructedType, pdl.AttributeType)
+        assert not isinstance(constructedType, pdl.OperationType)
+        assert isinstance(constructedType, pdl.RangeType)
+        assert not isinstance(constructedType, pdl.TypeType)
+        assert not isinstance(constructedType, pdl.ValueType)
 
         assert parsedType == constructedType
         assert elementType == typeType
@@ -103,17 +103,17 @@ def test_type_type():
         parsedType = Type.parse("!pdl.type")
         constructedType = pdl.TypeType.get()
 
-        assert not pdl.AttributeType.isinstance(parsedType)
-        assert not pdl.OperationType.isinstance(parsedType)
-        assert not pdl.RangeType.isinstance(parsedType)
-        assert pdl.TypeType.isinstance(parsedType)
-        assert not pdl.ValueType.isinstance(parsedType)
+        assert not isinstance(parsedType, pdl.AttributeType)
+        assert not isinstance(parsedType, pdl.OperationType)
+        assert not isinstance(parsedType, pdl.RangeType)
+        assert isinstance(parsedType, pdl.TypeType)
+        assert not isinstance(parsedType, pdl.ValueType)
 
-        assert not pdl.AttributeType.isinstance(constructedType)
-        assert not pdl.OperationType.isinstance(constructedType)
-        assert not pdl.RangeType.isinstance(constructedType)
-        assert pdl.TypeType.isinstance(constructedType)
-        assert not pdl.ValueType.isinstance(constructedType)
+        assert not isinstance(constructedType, pdl.AttributeType)
+        assert not isinstance(constructedType, pdl.OperationType)
+        assert not isinstance(constructedType, pdl.RangeType)
+        assert isinstance(constructedType, pdl.TypeType)
+        assert not isinstance(constructedType, pdl.ValueType)
 
         assert parsedType == constructedType
 
@@ -130,17 +130,17 @@ def test_value_type():
         parsedType = Type.parse("!pdl.value")
         constructedType = pdl.ValueType.get()
 
-        assert not pdl.AttributeType.isinstance(parsedType)
-        assert not pdl.OperationType.isinstance(parsedType)
-        assert not pdl.RangeType.isinstance(parsedType)
-        assert not pdl.TypeType.isinstance(parsedType)
-        assert pdl.ValueType.isinstance(parsedType)
-
-        assert not pdl.AttributeType.isinstance(constructedType)
-        assert not pdl.OperationType.isinstance(constructedType)
-        assert not pdl.RangeType.isinstance(constructedType)
-        assert not pdl.TypeType.isinstance(constructedType)
-        assert pdl.ValueType.isinstance(constructedType)
+        assert not isinstance(parsedType, pdl.AttributeType)
+        assert not isinstance(parsedType, pdl.OperationType)
+        assert not isinstance(parsedType, pdl.RangeType)
+        assert not isinstance(parsedType, pdl.TypeType)
+        assert isinstance(parsedType, pdl.ValueType)
+
+        assert not isinstance(constructedType, pdl.AttributeType)
+        assert not isinstance(constructedType, pdl.OperationType)
+        assert not isinstance(constructedType, pdl.RangeType)
+        assert not isinstance(constructedType, pdl.TypeType)
+        assert isinstance(constructedType, pdl.ValueType)
 
         assert parsedType == constructedType
 
diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py
index 57c528da7b9eb..f40d1224c9e4e 100644
--- a/mlir/test/python/dialects/quant.py
+++ b/mlir/test/python/dialects/quant.py
@@ -24,23 +24,23 @@ def test_type_hierarchy():
         )
         calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
 
-        assert not quant.QuantizedType.isinstance(i8)
-        assert quant.QuantizedType.isinstance(any)
-        assert quant.QuantizedType.isinstance(uniform)
-        assert quant.QuantizedType.isinstance(per_axis)
-        assert quant.QuantizedType.isinstance(sub_channel)
-        assert quant.QuantizedType.isinstance(calibrated)
-
-        assert quant.AnyQuantizedType.isinstance(any)
-        assert quant.UniformQuantizedType.isinstance(uniform)
-        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
-        assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel)
-        assert quant.CalibratedQuantizedType.isinstance(calibrated)
-
-        assert not quant.AnyQuantizedType.isinstance(uniform)
-        assert not quant.UniformQuantizedType.isinstance(per_axis)
-        assert not quant.UniformQuantizedType.isinstance(sub_channel)
-        assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel)
+        assert not isinstance(i8, quant.QuantizedType)
+        assert isinstance(any, quant.QuantizedType)
+        assert isinstance(uniform, quant.QuantizedType)
+        assert isinstance(per_axis, quant.QuantizedType)
+        assert isinstance(sub_channel, quant.QuantizedType)
+        assert isinstance(calibrated, quant.QuantizedType)
+
+        assert isinstance(any, quant.AnyQuantizedType)
+        assert isinstance(uniform, quant.UniformQuantizedType)
+        assert isinstance(per_axis, quant.UniformQuantizedPerAxisType)
+        assert isinstance(sub_channel, quant.UniformQuantizedSubChannelType)
+        assert isinstance(calibrated, quant.CalibratedQuantizedType)
+
+        assert not isinstance(uniform, quant.AnyQuantizedType)
+        assert not isinstance(per_axis, quant.UniformQuantizedType)
+        assert not isinstance(sub_channel, quant.UniformQuantizedType)
+        assert not isinstance(sub_channel, quant.UniformQuantizedPerAxisType)
 
 
 # CHECK-LABEL: TEST: test_any_quantized_type

>From bced1dc82bb5ebb2c7d3430da13d30eb4e0fd5ca Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 5 Jan 2026 12:57:47 -0800
Subject: [PATCH 4/4] add typeid getters to pdl and quant

---
 mlir/include/mlir-c/Dialect/PDL.h         | 10 ++++++++++
 mlir/include/mlir-c/Dialect/Quant.h       | 10 ++++++++++
 mlir/lib/Bindings/Python/DialectPDL.cpp   | 10 ++++++++++
 mlir/lib/Bindings/Python/DialectQuant.cpp | 10 ++++++++++
 mlir/lib/CAPI/Dialect/PDL.cpp             | 20 ++++++++++++++++++++
 mlir/lib/CAPI/Dialect/Quant.cpp           | 20 ++++++++++++++++++++
 6 files changed, 80 insertions(+)

diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h
index 6ad2e2da62d87..d04f69e391b13 100644
--- a/mlir/include/mlir-c/Dialect/PDL.h
+++ b/mlir/include/mlir-c/Dialect/PDL.h
@@ -30,6 +30,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPDLAttributeTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -38,6 +40,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPDLOperationTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -46,6 +50,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPDLRangeTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
 
 MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
@@ -56,6 +62,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPDLTypeTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -64,6 +72,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPDLValueTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx);
 
 #ifdef __cplusplus
diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h
index dc0989e53344e..f961c01d5dc2a 100644
--- a/mlir/include/mlir-c/Dialect/Quant.h
+++ b/mlir/include/mlir-c/Dialect/Quant.h
@@ -103,6 +103,8 @@ mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate);
 /// Returns `true` if the given type is an AnyQuantizedType.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirAnyQuantizedTypeGetTypeID(void);
+
 /// Creates an instance of AnyQuantizedType with the given parameters in the
 /// same context as `storageType` and returns it. The instance is owned by the
 /// context.
@@ -119,6 +121,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags,
 /// Returns `true` if the given type is a UniformQuantizedType.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedTypeGetTypeID(void);
+
 /// Creates an instance of UniformQuantizedType with the given parameters in the
 /// same context as `storageType` and returns it. The instance is owned by the
 /// context.
@@ -142,6 +146,8 @@ MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type);
 /// Returns `true` if the given type is a UniformQuantizedPerAxisType.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void);
+
 /// Creates an instance of UniformQuantizedPerAxisType with the given parameters
 /// in the same context as `storageType` and returns it. `scales` and
 /// `zeroPoints` point to `nDims` number of elements. The instance is owned
@@ -180,6 +186,8 @@ mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
 MLIR_CAPI_EXPORTED bool
 mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void);
+
 /// Creates a UniformQuantizedSubChannelType with the given parameters.
 ///
 /// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
@@ -220,6 +228,8 @@ mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
 /// Returns `true` if the given type is a CalibratedQuantizedType.
 MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void);
+
 /// Creates an instance of CalibratedQuantizedType with the given parameters
 /// in the same context as `expressedType` and returns it. The instance is owned
 /// by the context.
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index ac72734ea5c21..17d0a83127018 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -39,6 +39,8 @@ struct PDLType : PyConcreteType<PDLType> {
 
 struct AttributeType : PyConcreteType<AttributeType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirPDLAttributeTypeGetTypeID;
   static constexpr const char *pyClassName = "AttributeType";
   using Base::Base;
 
@@ -60,6 +62,8 @@ struct AttributeType : PyConcreteType<AttributeType> {
 
 struct OperationType : PyConcreteType<OperationType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirPDLOperationTypeGetTypeID;
   static constexpr const char *pyClassName = "OperationType";
   using Base::Base;
 
@@ -81,6 +85,8 @@ struct OperationType : PyConcreteType<OperationType> {
 
 struct RangeType : PyConcreteType<RangeType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirPDLRangeTypeGetTypeID;
   static constexpr const char *pyClassName = "RangeType";
   using Base::Base;
 
@@ -109,6 +115,8 @@ struct RangeType : PyConcreteType<RangeType> {
 
 struct TypeType : PyConcreteType<TypeType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirPDLTypeTypeGetTypeID;
   static constexpr const char *pyClassName = "TypeType";
   using Base::Base;
 
@@ -130,6 +138,8 @@ struct TypeType : PyConcreteType<TypeType> {
 
 struct ValueType : PyConcreteType<ValueType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirPDLValueTypeGetTypeID;
   static constexpr const char *pyClassName = "ValueType";
   using Base::Base;
 
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 9c6a15c97134d..3a6a91f3058ab 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -192,6 +192,8 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
 
 struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAnyQuantizedTypeGetTypeID;
   static constexpr const char *pyClassName = "AnyQuantizedType";
   using Base::Base;
 
@@ -221,6 +223,8 @@ struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
 struct UniformQuantizedType
     : PyConcreteType<UniformQuantizedType, QuantizedType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUniformQuantizedTypeGetTypeID;
   static constexpr const char *pyClassName = "UniformQuantizedType";
   using Base::Base;
 
@@ -273,6 +277,8 @@ struct UniformQuantizedPerAxisType
     : PyConcreteType<UniformQuantizedPerAxisType, QuantizedType> {
   static constexpr IsAFunctionTy isaFunction =
       mlirTypeIsAUniformQuantizedPerAxisType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUniformQuantizedPerAxisTypeGetTypeID;
   static constexpr const char *pyClassName = "UniformQuantizedPerAxisType";
   using Base::Base;
 
@@ -357,6 +363,8 @@ struct UniformQuantizedSubChannelType
     : PyConcreteType<UniformQuantizedSubChannelType, QuantizedType> {
   static constexpr IsAFunctionTy isaFunction =
       mlirTypeIsAUniformQuantizedSubChannelType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUniformQuantizedSubChannelTypeGetTypeID;
   static constexpr const char *pyClassName = "UniformQuantizedSubChannelType";
   using Base::Base;
 
@@ -448,6 +456,8 @@ struct CalibratedQuantizedType
     : PyConcreteType<CalibratedQuantizedType, QuantizedType> {
   static constexpr IsAFunctionTy isaFunction =
       mlirTypeIsACalibratedQuantizedType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirCalibratedQuantizedTypeGetTypeID;
   static constexpr const char *pyClassName = "CalibratedQuantizedType";
   using Base::Base;
 
diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp
index bd8b13c6516e2..88cd6056480f1 100644
--- a/mlir/lib/CAPI/Dialect/PDL.cpp
+++ b/mlir/lib/CAPI/Dialect/PDL.cpp
@@ -32,6 +32,10 @@ bool mlirTypeIsAPDLAttributeType(MlirType type) {
   return isa<pdl::AttributeType>(unwrap(type));
 }
 
+MlirTypeID mlirPDLAttributeTypeGetTypeID(void) {
+  return wrap(pdl::AttributeType::getTypeID());
+}
+
 MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
   return wrap(pdl::AttributeType::get(unwrap(ctx)));
 }
@@ -44,6 +48,10 @@ bool mlirTypeIsAPDLOperationType(MlirType type) {
   return isa<pdl::OperationType>(unwrap(type));
 }
 
+MlirTypeID mlirPDLOperationTypeGetTypeID(void) {
+  return wrap(pdl::OperationType::getTypeID());
+}
+
 MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
   return wrap(pdl::OperationType::get(unwrap(ctx)));
 }
@@ -56,6 +64,10 @@ bool mlirTypeIsAPDLRangeType(MlirType type) {
   return isa<pdl::RangeType>(unwrap(type));
 }
 
+MlirTypeID mlirPDLRangeTypeGetTypeID(void) {
+  return wrap(pdl::RangeType::getTypeID());
+}
+
 MlirType mlirPDLRangeTypeGet(MlirType elementType) {
   return wrap(pdl::RangeType::get(unwrap(elementType)));
 }
@@ -72,6 +84,10 @@ bool mlirTypeIsAPDLTypeType(MlirType type) {
   return isa<pdl::TypeType>(unwrap(type));
 }
 
+MlirTypeID mlirPDLTypeTypeGetTypeID(void) {
+  return wrap(pdl::TypeType::getTypeID());
+}
+
 MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
   return wrap(pdl::TypeType::get(unwrap(ctx)));
 }
@@ -84,6 +100,10 @@ bool mlirTypeIsAPDLValueType(MlirType type) {
   return isa<pdl::ValueType>(unwrap(type));
 }
 
+MlirTypeID mlirPDLValueTypeGetTypeID(void) {
+  return wrap(pdl::ValueType::getTypeID());
+}
+
 MlirType mlirPDLValueTypeGet(MlirContext ctx) {
   return wrap(pdl::ValueType::get(unwrap(ctx)));
 }
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 01a6a948f1dc0..840051caab852 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -113,6 +113,10 @@ bool mlirTypeIsAAnyQuantizedType(MlirType type) {
   return isa<quant::AnyQuantizedType>(unwrap(type));
 }
 
+MlirTypeID mlirAnyQuantizedTypeGetTypeID(void) {
+  return wrap(quant::AnyQuantizedType::getTypeID());
+}
+
 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
                                  MlirType expressedType, int64_t storageTypeMin,
                                  int64_t storageTypeMax) {
@@ -129,6 +133,10 @@ bool mlirTypeIsAUniformQuantizedType(MlirType type) {
   return isa<quant::UniformQuantizedType>(unwrap(type));
 }
 
+MlirTypeID mlirUniformQuantizedTypeGetTypeID(void) {
+  return wrap(quant::UniformQuantizedType::getTypeID());
+}
+
 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
                                      MlirType expressedType, double scale,
                                      int64_t zeroPoint, int64_t storageTypeMin,
@@ -158,6 +166,10 @@ bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
   return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
 }
 
+MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void) {
+  return wrap(quant::UniformQuantizedPerAxisType::getTypeID());
+}
+
 MlirType mlirUniformQuantizedPerAxisTypeGet(
     unsigned flags, MlirType storageType, MlirType expressedType,
     intptr_t nDims, double *scales, int64_t *zeroPoints,
@@ -203,6 +215,10 @@ bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
   return isa<quant::UniformQuantizedSubChannelType>(unwrap(type));
 }
 
+MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void) {
+  return wrap(quant::UniformQuantizedSubChannelType::getTypeID());
+}
+
 MlirType mlirUniformQuantizedSubChannelTypeGet(
     unsigned flags, MlirType storageType, MlirType expressedType,
     MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
@@ -258,6 +274,10 @@ bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
   return isa<quant::CalibratedQuantizedType>(unwrap(type));
 }
 
+MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void) {
+  return wrap(quant::CalibratedQuantizedType::getTypeID());
+}
+
 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
                                         double max) {
   return wrap(



More information about the Mlir-commits mailing list