[Mlir-commits] [mlir] [mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. (PR #65726)
Ingo Müller
llvmlistbot at llvm.org
Fri Sep 8 01:50:54 PDT 2023
https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/65726:
`_get_op_result_or_value` was used in mix-ins to unify the handling of op results and values. However, that function is now called in the generated constructors, so calling it in the mix-ins as well is no longer necessary.
>From 00c5c8153aeac74006c9ebc2b6d051d0cbd7fb79 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 7 Sep 2023 14:36:53 +0000
Subject: [PATCH] [mlir][linalg][transform][python] Drop _get_op_result... from
mix-ins.
`_get_op_result_or_value` was used in mix-ins to unify the handling of
op results and values. However, that function is now called in the
generated constructors, such that doing so in the mix-ins is not
necessary anymore.
---
.../dialects/_structured_transform_ops_ext.py | 47 +++++++------------
1 file changed, 18 insertions(+), 29 deletions(-)
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index f368e56f9818915..212fbc5badcbce8 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -4,7 +4,6 @@
try:
from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl, transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -101,7 +100,7 @@ def _dispatch_mixed_values(
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
- dynamic_values.append(_get_op_result_or_value(size))
+ dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)
return (dynamic_values, packed_values, static_values)
@@ -204,9 +203,7 @@ class DecomposeOp:
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(
- pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
- )
+ super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
class FuseIntoContainingOp:
@@ -277,9 +274,7 @@ class GeneralizeOp:
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(
- pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
- )
+ super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
class InterchangeOp:
@@ -296,7 +291,7 @@ def __init__(
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
- _get_op_result_or_value(target),
+ target,
iterator_interchange=iterator_interchange,
loc=loc,
ip=ip,
@@ -415,7 +410,7 @@ def match_op_names(
loc=None,
ip=None,
):
- ...
+ ...
@overload
@classmethod
@@ -428,7 +423,7 @@ def match_op_names(
loc=None,
ip=None,
):
- ...
+ ...
@classmethod
def match_op_names(
@@ -441,20 +436,20 @@ def match_op_names(
ip=None,
):
if isinstance(result_type_or_target, Type):
- result_type = result_type_or_target
- target = target_or_names
- names = names_or_none
+ result_type = result_type_or_target
+ target = target_or_names
+ names = names_or_none
else:
- result_type = transform.AnyOpType.get()
- target = result_type_or_target
- names = target_or_names
+ result_type = transform.AnyOpType.get()
+ target = result_type_or_target
+ names = target_or_names
if isinstance(names, str):
- names = [names]
+ names = [names]
return cls(
result_type,
- _get_op_result_or_value(target),
+ target,
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
ip=ip,
@@ -479,7 +474,7 @@ def __init__(
result_type,
result_type,
result_type,
- _get_op_result_or_value(target),
+ target,
dimension=dimension,
target_size=target_size,
divisor=divisor,
@@ -530,9 +525,7 @@ class ScalarizeOp:
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
- super().__init__(
- pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
- )
+ super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
class SplitOp:
@@ -552,9 +545,7 @@ def __init__(
dynamic_split_point = None
else:
static_split_point = ShapedType.get_dynamic_size()
- dynamic_split_point = _get_op_result_or_value(split_point)
-
- target = _get_op_result_or_value(target)
+ dynamic_split_point = split_point
super().__init__(
target.type,
@@ -626,8 +617,6 @@ def __init__(
)
target = target_or_none
- target = _get_op_result_or_value(target)
-
super().__init__(
target.type,
loop_types,
@@ -750,7 +739,7 @@ def __init__(
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
- _get_op_result_or_value(target),
+ target,
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
vectorize_nd_extract=vectorize_nd_extract,
More information about the Mlir-commits mailing list