[Mlir-commits] [mlir] [mlir][Python] use canonical Python `isinstance` instead of `Type.isinstance` (PR #172892)
Maksim Levental
llvmlistbot at llvm.org
Thu Dec 18 11:26:40 PST 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/172892
We've been able to write `isinstance(x, Type)` for quite a while now (since https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb), so this removes the special-casing in several places. Checking against the generic `FloatType` instead of an explicit list of `BF16Type`/`F16Type`/`F32Type`/`F64Type` also adds support for the various `fp8`, `fp6`, and `fp4` types. Note that we could consider removing `Type.isinstance` entirely...
>From 18cc4488f39915c1c7d06718bc37f76bbcb97726 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 18 Dec 2025 11:22:57 -0800
Subject: [PATCH] [mlir][Python] use canonical Python isinstance instead of
Type.isinstance
---
mlir/python/mlir/dialects/arith.py | 30 +-----
.../dialects/linalg/opdsl/lang/emitter.py | 101 ++++++------------
mlir/python/mlir/dialects/memref.py | 15 ++-
mlir/test/python/dialects/arith_dialect.py | 6 +-
4 files changed, 51 insertions(+), 101 deletions(-)
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 88e8502a29eae..59a343435f7f0 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -21,39 +21,17 @@
raise RuntimeError("Error loading imports from extension module") from e
-def _isa(obj: Any, cls: type):
- try:
- cls(obj)
- except ValueError:
- return False
- return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
- return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
- return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
- return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
@overload
- def __init__(self, value: Attribute, *, loc=None, ip=None):
- ...
+ def __init__(self, value: Attribute, *, loc=None, ip=None): ...
@overload
def __init__(
self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
- ):
- ...
+ ): ...
def __init__(self, result, value, *, loc=None, ip=None):
if value is None:
@@ -96,9 +74,9 @@ def value(self):
@property
def literal_value(self) -> Union[int, float]:
- if _is_integer_like_type(self.type):
+ if isinstance(self.type, (IntegerType, IndexType)):
return IntegerAttr(self.value).value
- elif _is_float_type(self.type):
+ elif isinstance(self.type, FloatType):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..a338643bad54d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -412,9 +412,9 @@ def _cast(
)
if operand.type == to_type:
return operand
- if _is_integer_type(to_type):
+ if isinstance(to_type, IntegerType):
return self._cast_to_integer(to_type, operand, is_unsigned_cast)
- elif _is_floating_point_type(to_type):
+ elif isinstance(to_type, FloatType):
return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
def _cast_to_integer(
@@ -422,11 +422,11 @@ def _cast_to_integer(
) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
- if _is_floating_point_type(operand_type):
+ if isinstance(operand_type, FloatType):
if is_unsigned_cast:
return arith.FPToUIOp(to_type, operand).result
return arith.FPToSIOp(to_type, operand).result
- if _is_index_type(operand_type):
+ if isinstance(operand_type, IndexType):
return arith.IndexCastOp(to_type, operand).result
# Assume integer.
from_width = IntegerType(operand_type).width
@@ -444,13 +444,15 @@ def _cast_to_floating_point(
self, to_type: Type, operand: Value, is_unsigned_cast: bool
) -> Value:
operand_type = operand.type
- if _is_integer_type(operand_type):
+ if isinstance(operand_type, IntegerType):
if is_unsigned_cast:
return arith.UIToFPOp(to_type, operand).result
return arith.SIToFPOp(to_type, operand).result
# Assume FloatType.
- to_width = _get_floating_point_width(to_type)
- from_width = _get_floating_point_width(operand_type)
+ assert isinstance(to_type, FloatType)
+ assert isinstance(operand_type, FloatType)
+ to_width = to_type.width
+ from_width = operand_type.width
if to_width > from_width:
return arith.ExtFOp(to_type, operand).result
elif to_width < from_width:
@@ -466,89 +468,89 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, True)
def _unary_exp(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if isinstance(x.type, FloatType):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
def _unary_log(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if isinstance(x.type, FloatType):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
def _unary_abs(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if isinstance(x.type, FloatType):
return math.AbsFOp(x).result
raise NotImplementedError("Unsupported 'abs' operand: {x}")
def _unary_ceil(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if isinstance(x.type, FloatType):
return math.CeilOp(x).result
raise NotImplementedError("Unsupported 'ceil' operand: {x}")
def _unary_floor(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if isinstance(x.type, FloatType):
return math.FloorOp(x).result
raise NotImplementedError("Unsupported 'floor' operand: {x}")
def _unary_negf(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if isinstance(x.type, FloatType):
return arith.NegFOp(x).result
- if _is_complex_type(x.type):
+ if isinstance(x.type, ComplexType):
return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if isinstance(lhs.type, FloatType):
return arith.AddFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.AddIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
+ if isinstance(lhs.type, ComplexType):
return complex.AddOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if isinstance(lhs.type, FloatType):
return arith.SubFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.SubIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
+ if isinstance(lhs.type, ComplexType):
return complex.SubOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if isinstance(lhs.type, FloatType):
return arith.MulFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MulIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
+ if isinstance(lhs.type, ComplexType):
return complex.MulOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if isinstance(lhs.type, FloatType):
return arith.MaximumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if isinstance(lhs.type, FloatType):
return arith.MaximumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if isinstance(lhs.type, FloatType):
return arith.MinimumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if isinstance(lhs.type, FloatType):
return arith.MinimumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MinUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
@@ -609,40 +611,3 @@ def _add_type_mapping(
)
type_mapping[name] = element_or_self_type
block_arg_types.append(element_or_self_type)
-
-
-def _is_complex_type(t: Type) -> bool:
- return ComplexType.isinstance(t)
-
-
-def _is_floating_point_type(t: Type) -> bool:
- # TODO: Create a FloatType in the Python API and implement the switch
- # there.
- return (
- F64Type.isinstance(t)
- or F32Type.isinstance(t)
- or F16Type.isinstance(t)
- or BF16Type.isinstance(t)
- )
-
-
-def _is_integer_type(t: Type) -> bool:
- return IntegerType.isinstance(t)
-
-
-def _is_index_type(t: Type) -> bool:
- return IndexType.isinstance(t)
-
-
-def _get_floating_point_width(t: Type) -> int:
- # TODO: Create a FloatType in the Python API and implement the switch
- # there.
- if F64Type.isinstance(t):
- return 64
- if F32Type.isinstance(t):
- return 32
- if F16Type.isinstance(t):
- return 16
- if BF16Type.isinstance(t):
- return 16
- raise NotImplementedError(f"Unhandled floating point type switch {t}")
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index c80a1b1a89358..fa60fe69a8bb8 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -7,15 +7,22 @@
from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
-from .arith import ConstantOp, _is_integer_like_type
-from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
+from ..ir import (
+ Value,
+ MemRefType,
+ StridedLayoutAttr,
+ ShapedType,
+ IntegerType,
+ IndexType,
+)
+from . import arith
def _is_constant_int_like(i):
return (
isinstance(i, Value)
- and isinstance(i.owner, ConstantOp)
- and _is_integer_like_type(i.type)
+ and isinstance(i.owner, arith.ConstantOp)
+ and isinstance(i.type, (IntegerType, IndexType))
)
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db8..a4cfb30240231 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -42,10 +42,10 @@ def testFastMathFlags():
def testArithValue():
def _binary_op(lhs, rhs, op: str) -> "ArithValue":
op = op.capitalize()
- if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+ if isinstance(lhs.type, FloatType) and isinstance(rhs.type, FloatType):
op += "F"
- elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
- lhs.type
+ elif isinstance(lhs.type, (IntegerType, IndexType)) and isinstance(
+ lhs.type, (IntegerType, IndexType)
):
op += "I"
else:
More information about the Mlir-commits
mailing list