[Mlir-commits] [mlir] [mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. (PR #65726)

Ingo Müller llvmlistbot at llvm.org
Fri Sep 8 01:50:54 PDT 2023


https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/65726:

`_get_op_result_or_value` was used in the mix-ins to unify the handling of op results and values. However, the generated constructors now call that function themselves, so doing so in the mix-ins is no longer necessary. This patch also fixes the indentation of `match_op_names`.

From 00c5c8153aeac74006c9ebc2b6d051d0cbd7fb79 Mon Sep 17 00:00:00 2001
From: Ingo Müller <ingomueller at google.com>
Date: Thu, 7 Sep 2023 14:36:53 +0000
Subject: [PATCH] [mlir][linalg][transform][python] Drop _get_op_result... from
 mix-ins.

`_get_op_result_or_value` was used in the mix-ins to unify the handling of
op results and values. However, the generated constructors now call that
function themselves, so doing so in the mix-ins is no longer necessary.
This patch also fixes the indentation of `match_op_names`.
---
 .../dialects/_structured_transform_ops_ext.py | 47 +++++++------------
 1 file changed, 18 insertions(+), 29 deletions(-)

diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index f368e56f9818915..212fbc5badcbce8 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -4,7 +4,6 @@
 
 try:
     from ..ir import *
-    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
     from ..dialects import pdl, transform
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
@@ -101,7 +100,7 @@ def _dispatch_mixed_values(
                 static_values.append(size)
             else:
                 static_values.append(ShapedType.get_dynamic_size())
-                dynamic_values.append(_get_op_result_or_value(size))
+                dynamic_values.append(size)
         static_values = DenseI64ArrayAttr.get(static_values)
 
     return (dynamic_values, packed_values, static_values)
@@ -204,9 +203,7 @@ class DecomposeOp:
     """Specialization for DecomposeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-        super().__init__(
-            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
-        )
+        super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
 
 
 class FuseIntoContainingOp:
@@ -277,9 +274,7 @@ class GeneralizeOp:
     """Specialization for GeneralizeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-        super().__init__(
-            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
-        )
+        super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
 
 
 class InterchangeOp:
@@ -296,7 +291,7 @@ def __init__(
         pdl_operation_type = pdl.OperationType.get()
         super().__init__(
             pdl_operation_type,
-            _get_op_result_or_value(target),
+            target,
             iterator_interchange=iterator_interchange,
             loc=loc,
             ip=ip,
@@ -415,7 +410,7 @@ def match_op_names(
         loc=None,
         ip=None,
     ):
-       ...
+        ...
 
     @overload
     @classmethod
@@ -428,7 +423,7 @@ def match_op_names(
         loc=None,
         ip=None,
     ):
-       ...
+        ...
 
     @classmethod
     def match_op_names(
@@ -441,20 +436,20 @@ def match_op_names(
         ip=None,
     ):
         if isinstance(result_type_or_target, Type):
-           result_type = result_type_or_target
-           target = target_or_names
-           names = names_or_none
+            result_type = result_type_or_target
+            target = target_or_names
+            names = names_or_none
         else:
-           result_type = transform.AnyOpType.get()
-           target = result_type_or_target
-           names = target_or_names
+            result_type = transform.AnyOpType.get()
+            target = result_type_or_target
+            names = target_or_names
 
         if isinstance(names, str):
-           names = [names]
+            names = [names]
 
         return cls(
             result_type,
-            _get_op_result_or_value(target),
+            target,
             ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
             loc=loc,
             ip=ip,
@@ -479,7 +474,7 @@ def __init__(
             result_type,
             result_type,
             result_type,
-            _get_op_result_or_value(target),
+            target,
             dimension=dimension,
             target_size=target_size,
             divisor=divisor,
@@ -530,9 +525,7 @@ class ScalarizeOp:
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
         pdl_operation_type = pdl.OperationType.get()
-        super().__init__(
-            pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
-        )
+        super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
 
 
 class SplitOp:
@@ -552,9 +545,7 @@ def __init__(
             dynamic_split_point = None
         else:
             static_split_point = ShapedType.get_dynamic_size()
-            dynamic_split_point = _get_op_result_or_value(split_point)
-
-        target = _get_op_result_or_value(target)
+            dynamic_split_point = split_point
 
         super().__init__(
             target.type,
@@ -626,8 +617,6 @@ def __init__(
             )
             target = target_or_none
 
-        target = _get_op_result_or_value(target)
-
         super().__init__(
             target.type,
             loop_types,
@@ -750,7 +739,7 @@ def __init__(
         pdl_operation_type = pdl.OperationType.get()
         super().__init__(
             pdl_operation_type,
-            _get_op_result_or_value(target),
+            target,
             disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
             disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
             vectorize_nd_extract=vectorize_nd_extract,



More information about the Mlir-commits mailing list