[Mlir-commits] [mlir] [mlir][py] NFC: remove exception-based isa from linalg module (PR #92556)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri May 17 07:25:30 PDT 2024


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/92556

When this code was written, we didn't have proper isinstance support for operation classes in Python. Now we do, so there is no reason to keep the expensive exception-based flow.
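The difference between the two checks can be sketched in plain Python. The classes below (`Type`, `RankedTensorType`, `MemRefType`) are hypothetical stand-ins that only mimic the "casting constructor raises `ValueError`" behaviour the removed `isa` helper relied on; they are not the real MLIR bindings. Only the two `isa_by_*` functions mirror the old and new patterns from the patch.

```python
# Hypothetical stand-ins: NOT the real mlir.ir classes, just enough to show the pattern.
class Type:
    """Generic type handle carrying a 'kind' tag."""

    def __init__(self, kind: str):
        self.kind = kind


class RankedTensorType(Type):
    """Casting constructor: accepts only types whose kind is 'ranked_tensor'."""

    def __init__(self, ty: Type):
        if ty.kind != "ranked_tensor":
            raise ValueError(f"cannot cast {ty.kind!r} to RankedTensorType")
        super().__init__(ty.kind)


class MemRefType(Type):
    """Casting constructor: accepts only types whose kind is 'memref'."""

    def __init__(self, ty: Type):
        if ty.kind != "memref":
            raise ValueError(f"cannot cast {ty.kind!r} to MemRefType")
        super().__init__(ty.kind)


def isa_by_exception(cls, ty):
    """Old pattern (removed by the patch): try the cast, treat ValueError as 'no'."""
    try:
        cls(ty)
        return True
    except ValueError:
        return False


def isa_by_isinstance(cls, ty):
    """New pattern: a plain class check with no exception machinery."""
    return isinstance(ty, cls)


# Type objects that are already instances of their concrete subclass.
tensor_t = RankedTensorType(Type("ranked_tensor"))
memref_t = MemRefType(Type("memref"))

print([isa_by_exception(RankedTensorType, t) for t in (tensor_t, memref_t)])
print([isa_by_isinstance(RankedTensorType, t) for t in (tensor_t, memref_t)])
```

In this sketch both checks agree on `tensor_t` and `memref_t`. Note that `isinstance` only inspects the Python class, so it gives the expected answer only when the type object is already an instance of the concrete subclass. A bare `Type("ranked_tensor")` would pass the exception-based check but not the `isinstance` one.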

From 33df710b4a8e9805dae994d9e74de7143b43c992 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <ftynse at gmail.com>
Date: Fri, 17 May 2024 16:24:08 +0200
Subject: [PATCH] [mlir][py] NFC: remove exception-based isa from linalg module

When this code was written, we didn't have proper isinstance support for
operation classes in Python. Now we do, so there is no reason to keep the
expensive exception-based flow.
---
 mlir/python/mlir/dialects/linalg/__init__.py           |  5 ++---
 mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | 10 +---------
 2 files changed, 3 insertions(+), 12 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 6e4cb1bd62671..8fb1227ee80ff 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -55,7 +55,6 @@
 #     TODO: guard against surprises and fail create Runtime Custom Ops with
 #     the same name as existing Core Named Ops.
 from .opdsl.ops.core_named_ops import *
-from .opdsl.lang.emitter import isa
 
 from ...ir import *
 from .._ods_common import get_op_result_or_value as _get_op_result_or_value
@@ -71,7 +70,7 @@ def transpose(
     if len(outs) > 1:
         raise ValueError(f"{outs=} must have length 1.")
     init = _get_op_result_or_value(outs[0])
-    result_types = [init.type] if isa(RankedTensorType, init.type) else []
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
 
     op = TransposeOp(
         result=result_types,
@@ -93,7 +92,7 @@ def broadcast(
     if len(outs) > 1:
         raise ValueError(f"{outs=} must have length 1.")
     init = _get_op_result_or_value(outs[0])
-    result_types = [init.type] if isa(RankedTensorType, init.type) else []
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
 
     op = BroadcastOp(
         result=result_types,
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 845b533db52a9..254458a978828 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -31,14 +31,6 @@
 ValueList = Union[Sequence[Value], OpResultList]
 
 
-def isa(cls: Type, ty: Type):
-    try:
-        cls(ty)
-        return True
-    except ValueError:
-        return False
-
-
 def prepare_common_structured_op(
     op_config: LinalgStructuredOpConfig,
     *ins: Value,
@@ -127,7 +119,7 @@ def prepare_common_structured_op(
         op_config, in_arg_defs, ins, out_arg_defs, outs
     )
 
-    result_types = [t for t in out_types if isa(RankedTensorType, t)]
+    result_types = [t for t in out_types if isinstance(t, RankedTensorType)]
 
     # Initialize the type dictionary with the predefined types.
     type_mapping = dict()  # type: Dict[str, Type]


