[Mlir-commits] [mlir] 57c090b - [mlir][linalg][transform][python] Improve mix-in for PadOp.

Ingo Müller llvmlistbot at llvm.org
Mon Aug 21 06:35:55 PDT 2023


Author: Ingo Müller
Date: 2023-08-21T13:35:49Z
New Revision: 57c090b2ea03937e7c6a08a594532788d01bb813

URL: https://github.com/llvm/llvm-project/commit/57c090b2ea03937e7c6a08a594532788d01bb813
DIFF: https://github.com/llvm/llvm-project/commit/57c090b2ea03937e7c6a08a594532788d01bb813.diff

LOG: [mlir][linalg][transform][python] Improve mix-in for PadOp.

In particular:

* Fix and extend the support for constructing possibly nested ArrayAttrs
  from lists of Python ints. This can probably be generalized further
  and used in many more places.
* Add arguments for `pad_to_multiple_of` and `copy_back_op`.
* Format with black and reorder the (keyword-only) arguments to match
  their order in the TableGen definition and the generated `*_gen.py` bindings.
* Extend the tests to cover the new features.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D157789

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index b63652957d03f3..48dee5f801f18c 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -107,28 +107,60 @@ def _dispatch_mixed_values(
     return (dynamic_values, packed_values, static_values)
 
 
-def _get_int_int_array_attr(
+def _get_value_or_attribute_value(
+    value_or_attr: Union[any, Attribute, ArrayAttr]
+) -> any:
+    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
+        return value_or_attr.value
+    if isinstance(value_or_attr, ArrayAttr):
+        return _get_value_list(value_or_attr)
+    return value_or_attr
+
+
+def _get_value_list(
+    sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
+) -> Sequence[any]:
+    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
+
+
+def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
+    if values is None:
+        return ArrayAttr.get([])
+
+    # Turn into a Python list of Python ints.
+    values = _get_value_list(values)
+
+    # Make an ArrayAttr of IntegerAttrs out of it.
+    return ArrayAttr.get(
+        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
+    )
+
+
+def _get_int_array_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
 ) -> ArrayAttr:
-    """Creates an array attribute containing array attributes of integers.
+    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
 
-    If the operand is already an array attribute, forwards it. Otherwise treats
-    the operand as a list of attributes or integers, potentially interpserced, to
-    create a new array-of-array attribute. Expects the thread-local MLIR context
-    to have been set by the context manager.
+    The input has to be a collection of collection of integers, where any
+    Python Sequence and ArrayAttr are admissible collections and Python ints and
+    any IntegerAttr are admissible integers. Both levels of collections are
+    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
+    If the input is None, an empty ArrayAttr is returned.
     """
     if values is None:
         return ArrayAttr.get([])
-    if isinstance(values, ArrayAttr):
-        return values
-    if isinstance(values, list):
-        values = [
-            ArrayAttr.get(
-                [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value]
-            )
-            for value in values
-        ]
 
+    # Make sure the outer level is a list.
+    values = _get_value_list(values)
+
+    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
+    # Sequences. Make sure the nested values are all lists.
+    values = [_get_value_list(nested) for nested in values]
+
+    # Turn each nested list into an ArrayAttr.
+    values = [_get_int_array_attr(nested) for nested in values]
+
+    # Turn the outer list into an ArrayAttr.
     return ArrayAttr.get(values)
 
 
@@ -455,44 +487,55 @@ def __init__(
 
 
 class PadOp:
-  """Specialization for PadOp class."""
+    """Specialization for PadOp class."""
 
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      padding_values: Optional[
-          Optional[Union[ArrayAttr, Sequence[Attribute]]]
-      ] = None,
-      padding_dimensions: OptionalIntList = None,
-      pack_paddings: OptionalIntList = None,
-      transpose_paddings: Optional[
-          Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
-      ] = None,
-      loc=None,
-      ip=None,
-  ):
-    if transpose_paddings is None:
-      transpose_paddings = []
-    if pack_paddings is None:
-      pack_paddings = []
-    if padding_dimensions is None:
-      padding_dimensions = []
-    if padding_values is None:
-      padding_values = []
-    pdl_operation_type = pdl.OperationType.get()
-    transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
-    super().__init__(
-        pdl_operation_type,
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        padding_values=padding_values,
-        padding_dimensions=padding_dimensions,
-        pack_paddings=pack_paddings,
-        transpose_paddings=transpose_paddings_attr,
-        loc=loc,
-        ip=ip,
-    )
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        *,
+        padding_values: Optional[
+            Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]]
+        ] = None,
+        padding_dimensions: OptionalIntList = None,
+        pad_to_multiple_of: OptionalIntList = None,
+        pack_paddings: OptionalIntList = None,
+        transpose_paddings: Optional[
+            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
+        ] = None,
+        copy_back_op: Optional[Union[str, StringAttr]] = None,
+        loc=None,
+        ip=None,
+    ):
+        if padding_values is None:
+            padding_values = []
+        if padding_dimensions is None:
+            padding_dimensions = []
+        if pad_to_multiple_of is None:
+            pad_to_multiple_of = []
+        if pack_paddings is None:
+            pack_paddings = []
+        if transpose_paddings is None:
+            transpose_paddings = []
+
+        padding_dimensions = _get_int_array_attr(padding_dimensions)
+        pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of)
+        pack_paddings = _get_int_array_attr(pack_paddings)
+        transpose_paddings = _get_int_array_array_attr(transpose_paddings)
+
+        pdl_operation_type = pdl.OperationType.get()
+        super().__init__(
+            pdl_operation_type,
+            pdl_operation_type,
+            target,
+            padding_values=padding_values,
+            padding_dimensions=padding_dimensions,
+            pad_to_multiple_of=pad_to_multiple_of,
+            pack_paddings=pack_paddings,
+            transpose_paddings=transpose_paddings,
+            copy_back_op=copy_back_op,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class ScalarizeOp:

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 12422ba29e2fef..2e3198b03d1d74 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -322,17 +322,22 @@ def testPad():
         structured.PadOp(
             sequence.bodyTarget,
             padding_values=[FloatAttr.get_f32(42.0)],
-            padding_dimensions=[1],
-            transpose_paddings=[[1, 0]],
+            padding_dimensions=Attribute.parse("[1]"),
+            pad_to_multiple_of=[128],
+            pack_paddings=[0],
+            transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
+            copy_back_op="linalg.copy",
         )
         transform.YieldOp()
     # CHECK-LABEL: TEST: testPad
     # CHECK: transform.sequence
     # CHECK: transform.structured.pad
-    # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+    # CHECK-DAG: copy_back_op = "linalg.copy"
+    # CHECK-DAG: pack_paddings = [0]
+    # CHECK-DAG: pad_to_multiple_of = [128]
     # CHECK-DAG: padding_dimensions = [1]
-    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
-    # (pack_paddings has default values)
+    # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]
 
 
 @run


        


More information about the Mlir-commits mailing list