[Mlir-commits] [mlir] 9442b44 - [mlir][linalg][transform][python] Fix optional args of PadOp mix-in.
Ingo Müller
llvmlistbot at llvm.org
Sat Sep 2 04:19:11 PDT 2023
Author: Ingo Müller
Date: 2023-09-02T11:19:06Z
New Revision: 9442b441c1c50e4e6782fd2e6aa16925c9d22e29
URL: https://github.com/llvm/llvm-project/commit/9442b441c1c50e4e6782fd2e6aa16925c9d22e29
DIFF: https://github.com/llvm/llvm-project/commit/9442b441c1c50e4e6782fd2e6aa16925c9d22e29.diff
LOG: [mlir][linalg][transform][python] Fix optional args of PadOp mix-in.
The mix-in did not allow callers to *not* set many of the arguments, even
though they represent optional attributes. Instead, it set default values,
which have different semantics in some cases. In other cases, the C++ layer
already sets the same default values, so setting them again in Python is
currently redundant and could become wrong if the defaults in the TableGen
(TD) or C++ files change in the future. With this patch, `None` is preserved
until it reaches the generated binding, which handles it as desired.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158844
Added:
Modified:
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index a5b4e52d5ff065..544171dc2acb63 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -125,7 +125,7 @@ def _get_value_list(
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
if values is None:
- return ArrayAttr.get([])
+ return None
# Turn into a Python list of Python ints.
values = _get_value_list(values)
@@ -148,7 +148,7 @@ def _get_int_array_array_attr(
If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
- return ArrayAttr.get([])
+ return None
# Make sure the outer level is a list.
values = _get_value_list(values)
@@ -493,9 +493,7 @@ def __init__(
self,
target: Union[Operation, OpView, Value],
*,
- padding_values: Optional[
- Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]]
- ] = None,
+ padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
padding_dimensions: OptionalIntList = None,
pad_to_multiple_of: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
@@ -506,17 +504,6 @@ def __init__(
loc=None,
ip=None,
):
- if padding_values is None:
- padding_values = []
- if padding_dimensions is None:
- padding_dimensions = []
- if pad_to_multiple_of is None:
- pad_to_multiple_of = []
- if pack_paddings is None:
- pack_paddings = []
- if transpose_paddings is None:
- transpose_paddings = []
-
padding_dimensions = _get_int_array_attr(padding_dimensions)
pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of)
pack_paddings = _get_int_array_attr(pack_paddings)
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index e624d93f25d82e..84081161ab6a83 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -314,14 +314,33 @@ def testMultitileSizes():
@run
-def testPad():
+def testPadOpNoArgs():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.PadOp(sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testPadOpNoArgs
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.pad
+ # CHECK-NOT: copy_back_op
+ # CHECK-NOT: pack_paddings
+ # CHECK-NOT: pad_to_multiple_of
+ # CHECK-NOT: padding_dimensions
+ # CHECK-NOT: padding_values
+ # CHECK-NOT: transpose_paddings
+
+
+@run
+def testPadOpArgs():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
)
with InsertionPoint(sequence.body):
structured.PadOp(
sequence.bodyTarget,
- padding_values=[FloatAttr.get_f32(42.0)],
+ padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
padding_dimensions=Attribute.parse("[1]"),
pad_to_multiple_of=[128],
pack_paddings=[0],
@@ -329,7 +348,7 @@ def testPad():
copy_back_op="linalg.copy",
)
transform.YieldOp()
- # CHECK-LABEL: TEST: testPad
+ # CHECK-LABEL: TEST: testPadOpArgs
# CHECK: transform.sequence
# CHECK: transform.structured.pad
# CHECK-DAG: copy_back_op = "linalg.copy"
More information about the Mlir-commits
mailing list