[Mlir-commits] [mlir] 57c090b - [mlir][linalg][transform][python] Improve mix-in for PadOp.
Ingo Müller
llvmlistbot at llvm.org
Mon Aug 21 06:35:55 PDT 2023
Author: Ingo Müller
Date: 2023-08-21T13:35:49Z
New Revision: 57c090b2ea03937e7c6a08a594532788d01bb813
URL: https://github.com/llvm/llvm-project/commit/57c090b2ea03937e7c6a08a594532788d01bb813
DIFF: https://github.com/llvm/llvm-project/commit/57c090b2ea03937e7c6a08a594532788d01bb813.diff
LOG: [mlir][linalg][transform][python] Improve mix-in for PadOp.
In particular:
* Fix and extend the support for constructing possibly nested ArrayAttrs
from lists of Python ints. This can probably be generalized further
and used in many more places.
* Add arguments for `pad_to_multiple_of` and `copy_back_op`.
* Format with black and reorder the (keyword-only) arguments to match
the order used in tablegen (and thus in the generated `*_gen.py`).
* Extend tests for new features.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D157789
Added:
Modified:
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index b63652957d03f3..48dee5f801f18c 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -107,28 +107,60 @@ def _dispatch_mixed_values(
return (dynamic_values, packed_values, static_values)
-def _get_int_int_array_attr(
+def _get_value_or_attribute_value(
+ value_or_attr: Union[any, Attribute, ArrayAttr]
+) -> any:
+ if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
+ return value_or_attr.value
+ if isinstance(value_or_attr, ArrayAttr):
+ return _get_value_list(value_or_attr)
+ return value_or_attr
+
+
+def _get_value_list(
+ sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
+) -> Sequence[any]:
+ return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
+
+
+def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
+ if values is None:
+ return ArrayAttr.get([])
+
+ # Turn into a Python list of Python ints.
+ values = _get_value_list(values)
+
+ # Make an ArrayAttr of IntegerAttrs out of it.
+ return ArrayAttr.get(
+ [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
+ )
+
+
+def _get_int_array_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
- """Creates an array attribute containing array attributes of integers.
+ """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
- If the operand is already an array attribute, forwards it. Otherwise treats
- the operand as a list of attributes or integers, potentially interpserced, to
- create a new array-of-array attribute. Expects the thread-local MLIR context
- to have been set by the context manager.
+ The input has to be a collection of collections of integers, where any
+ Python Sequence or ArrayAttr is an admissible collection and any Python int
+ or IntegerAttr is an admissible integer. Both levels of collections are
+ turned into ArrayAttrs; the integers are turned into i64 IntegerAttrs.
+ If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
return ArrayAttr.get([])
- if isinstance(values, ArrayAttr):
- return values
- if isinstance(values, list):
- values = [
- ArrayAttr.get(
- [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value]
- )
- for value in values
- ]
+ # Make sure the outer level is a list.
+ values = _get_value_list(values)
+
+ # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
+ # Sequences. Make sure the nested values are all lists.
+ values = [_get_value_list(nested) for nested in values]
+
+ # Turn each nested list into an ArrayAttr.
+ values = [_get_int_array_attr(nested) for nested in values]
+
+ # Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
@@ -455,44 +487,55 @@ def __init__(
class PadOp:
- """Specialization for PadOp class."""
+ """Specialization for PadOp class."""
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- padding_values: Optional[
- Optional[Union[ArrayAttr, Sequence[Attribute]]]
- ] = None,
- padding_dimensions: OptionalIntList = None,
- pack_paddings: OptionalIntList = None,
- transpose_paddings: Optional[
- Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
- ] = None,
- loc=None,
- ip=None,
- ):
- if transpose_paddings is None:
- transpose_paddings = []
- if pack_paddings is None:
- pack_paddings = []
- if padding_dimensions is None:
- padding_dimensions = []
- if padding_values is None:
- padding_values = []
- pdl_operation_type = pdl.OperationType.get()
- transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
- super().__init__(
- pdl_operation_type,
- pdl_operation_type,
- _get_op_result_or_value(target),
- padding_values=padding_values,
- padding_dimensions=padding_dimensions,
- pack_paddings=pack_paddings,
- transpose_paddings=transpose_paddings_attr,
- loc=loc,
- ip=ip,
- )
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ padding_values: Optional[
+ Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]]
+ ] = None,
+ padding_dimensions: OptionalIntList = None,
+ pad_to_multiple_of: OptionalIntList = None,
+ pack_paddings: OptionalIntList = None,
+ transpose_paddings: Optional[
+ Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
+ ] = None,
+ copy_back_op: Optional[Union[str, StringAttr]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if padding_values is None:
+ padding_values = []
+ if padding_dimensions is None:
+ padding_dimensions = []
+ if pad_to_multiple_of is None:
+ pad_to_multiple_of = []
+ if pack_paddings is None:
+ pack_paddings = []
+ if transpose_paddings is None:
+ transpose_paddings = []
+
+ padding_dimensions = _get_int_array_attr(padding_dimensions)
+ pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of)
+ pack_paddings = _get_int_array_attr(pack_paddings)
+ transpose_paddings = _get_int_array_array_attr(transpose_paddings)
+
+ pdl_operation_type = pdl.OperationType.get()
+ super().__init__(
+ pdl_operation_type,
+ pdl_operation_type,
+ target,
+ padding_values=padding_values,
+ padding_dimensions=padding_dimensions,
+ pad_to_multiple_of=pad_to_multiple_of,
+ pack_paddings=pack_paddings,
+ transpose_paddings=transpose_paddings,
+ copy_back_op=copy_back_op,
+ loc=loc,
+ ip=ip,
+ )
class ScalarizeOp:
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 12422ba29e2fef..2e3198b03d1d74 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -322,17 +322,22 @@ def testPad():
structured.PadOp(
sequence.bodyTarget,
padding_values=[FloatAttr.get_f32(42.0)],
- padding_dimensions=[1],
- transpose_paddings=[[1, 0]],
+ padding_dimensions=Attribute.parse("[1]"),
+ pad_to_multiple_of=[128],
+ pack_paddings=[0],
+ transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
+ copy_back_op="linalg.copy",
)
transform.YieldOp()
# CHECK-LABEL: TEST: testPad
# CHECK: transform.sequence
# CHECK: transform.structured.pad
- # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+ # CHECK-DAG: copy_back_op = "linalg.copy"
+ # CHECK-DAG: pack_paddings = [0]
+ # CHECK-DAG: pad_to_multiple_of = [128]
# CHECK-DAG: padding_dimensions = [1]
- # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
- # (pack_paddings has default values)
+ # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+ # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]
@run
More information about the Mlir-commits
mailing list