[Mlir-commits] [mlir] ea4a512 - [mlir][linalg][transform][python] Refactor TileOp mix-in.
Ingo Müller
llvmlistbot at llvm.org
Mon Sep 4 04:32:18 PDT 2023
Author: Ingo Müller
Date: 2023-09-04T11:32:14Z
New Revision: ea4a5127c49d01e7415054e9ebeb813c98159810
URL: https://github.com/llvm/llvm-project/commit/ea4a5127c49d01e7415054e9ebeb813c98159810
DIFF: https://github.com/llvm/llvm-project/commit/ea4a5127c49d01e7415054e9ebeb813c98159810.diff
LOG: [mlir][linalg][transform][python] Refactor TileOp mix-in.
This patch simplifies and improves the mix-in of the `TileOp`. In
particular:
* Accept all types of sizes (static, dynamic, scalable) in a single
argument `sizes`.
* Use the existing convenience function to dispatch different types of
sizes instead of repeating the implementation in the mix-in.
* Pass `None` values of optional arguments on as-is to the init function
of the super class.
* Reformat with default indentation width (4 spaces vs 2 spaces).
* Add a test for providing scalable sizes.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D159417
Added:
Modified:
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 792629457552b6..44d9c1406fd89a 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -571,107 +571,77 @@ def __init__(
class TileOp:
- """Specialization for TileOp class."""
+ """Specialization for TileOp class."""
- @overload
- def __init__(
+ @overload
+ def __init__(
self,
loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
- sizes: Optional[
- Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
- ] = None,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
- scalable_sizes: OptionalBoolList = None,
loc=None,
ip=None,
):
- ...
+ ...
- @overload
- def __init__(
+ @overload
+ def __init__(
self,
target: Union[Operation, Value, OpView],
*,
- sizes: Optional[
- Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
- ] = None,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
- scalable_sizes: OptionalBoolList = None,
loc=None,
ip=None,
):
- ...
+ ...
- def __init__(
+ def __init__(
self,
loop_types_or_target: Union[Type, List[Type], Operation, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
- sizes: Optional[
- Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
- ] = None,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
- scalable_sizes: OptionalBoolList = None,
loc=None,
ip=None,
):
- if interchange is None:
- interchange = []
- if sizes is None:
- sizes = []
-
- static_sizes = []
- dynamic_sizes = []
- if isinstance(sizes, ArrayAttr):
- sizes_attr = sizes
- else:
- for size in sizes:
- if isinstance(size, int):
- static_sizes.append(size)
- else:
- static_sizes.append(ShapedType.get_dynamic_size())
- dynamic_sizes.append(_get_op_result_or_value(size))
- sizes_attr = DenseI64ArrayAttr.get(static_sizes)
+ (
+ dynamic_sizes,
+ static_sizes,
+ scalable_sizes,
+ ) = _dispatch_dynamic_index_list(sizes)
- num_loops = sum(
- v if v == 0 else 1 for v in self.__extract_values(sizes_attr)
- )
- if scalable_sizes is None:
- scalable_sizes = [False] * len(self.__extract_values(sizes_attr))
+ num_loops = sum(v if v == 0 else 1 for v in static_sizes)
- if isinstance(loop_types_or_target, (Operation, Value, OpView)):
- loop_types = [transform.AnyOpType.get()] * num_loops
- target = loop_types_or_target
- assert target_or_none is None, "Cannot construct TileOp with two targets."
- else:
- loop_types = (
- ([loop_types_or_target] * num_loops)
- if isinstance(loop_types_or_target, Type)
- else loop_types_or_target
- )
- target = target_or_none
+ if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+ loop_types = [transform.AnyOpType.get()] * num_loops
+ target = loop_types_or_target
+ assert target_or_none is None, "Cannot construct TileOp with two targets."
+ else:
+ loop_types = (
+ ([loop_types_or_target] * num_loops)
+ if isinstance(loop_types_or_target, Type)
+ else loop_types_or_target
+ )
+ target = target_or_none
- target = _get_op_result_or_value(target)
+ target = _get_op_result_or_value(target)
- super().__init__(
+ super().__init__(
target.type,
loop_types,
target,
dynamic_sizes=dynamic_sizes,
- static_sizes=sizes_attr,
+ static_sizes=static_sizes,
interchange=interchange,
scalable_sizes=scalable_sizes,
loc=loc,
ip=ip,
)
- def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
- if not attr:
- return []
- return [element for element in attr]
-
class TileToForallOp:
"""Specialization for TileToForallOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index acf14c197fec41..66e459917163c1 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -486,6 +486,22 @@ def testTileExplicitLoopTypeAll():
# CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
+@run
+def testTileScalable():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.TileOp(
+ sequence.bodyTarget,
+ sizes=[4, [2]],
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testTileScalable
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]]
+
+
@run
def testTileToForallCompact():
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list