[Mlir-commits] [mlir] ea4a512 - [mlir][linalg][transform][python] Refactor TileOp mix-in.

Ingo Müller llvmlistbot at llvm.org
Mon Sep 4 04:32:18 PDT 2023


Author: Ingo Müller
Date: 2023-09-04T11:32:14Z
New Revision: ea4a5127c49d01e7415054e9ebeb813c98159810

URL: https://github.com/llvm/llvm-project/commit/ea4a5127c49d01e7415054e9ebeb813c98159810
DIFF: https://github.com/llvm/llvm-project/commit/ea4a5127c49d01e7415054e9ebeb813c98159810.diff

LOG: [mlir][linalg][transform][python] Refactor TileOp mix-in.

This patch simplifies and improves the mix-in of the `TileOp`. In
particular:

* Accept all types of sizes (static, dynamic, scalable) in a single
  argument `sizes`.
* Use the existing convenience function to dispatch different types of
  sizes instead of repeating the implementation in the mix-in.
* Pass `None` values of optional arguments on unchanged to the init function
  of the super class.
* Reformat with default indentation width (4 spaces vs 2 spaces).
* Add a test for providing scalable sizes.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D159417

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 792629457552b6..44d9c1406fd89a 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -571,107 +571,77 @@ def __init__(
 
 
 class TileOp:
-  """Specialization for TileOp class."""
+    """Specialization for TileOp class."""
 
-  @overload
-  def __init__(
+    @overload
+    def __init__(
         self,
         loop_types: Union[Type, List[Type]],
         target: Union[Operation, Value],
         *,
-        sizes: Optional[
-            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
-        ] = None,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         interchange: OptionalIntList = None,
-        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
-    ...
+        ...
 
-  @overload
-  def __init__(
+    @overload
+    def __init__(
         self,
         target: Union[Operation, Value, OpView],
         *,
-        sizes: Optional[
-            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
-        ] = None,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         interchange: OptionalIntList = None,
-        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
-    ...
+        ...
 
-  def __init__(
+    def __init__(
         self,
         loop_types_or_target: Union[Type, List[Type], Operation, Value],
         target_or_none: Optional[Union[Operation, Value, OpView]] = None,
         *,
-        sizes: Optional[
-            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
-        ] = None,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         interchange: OptionalIntList = None,
-        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
-    if interchange is None:
-      interchange = []
-    if sizes is None:
-      sizes = []
-
-    static_sizes = []
-    dynamic_sizes = []
-    if isinstance(sizes, ArrayAttr):
-      sizes_attr = sizes
-    else:
-      for size in sizes:
-        if isinstance(size, int):
-          static_sizes.append(size)
-        else:
-          static_sizes.append(ShapedType.get_dynamic_size())
-          dynamic_sizes.append(_get_op_result_or_value(size))
-      sizes_attr = DenseI64ArrayAttr.get(static_sizes)
+        (
+            dynamic_sizes,
+            static_sizes,
+            scalable_sizes,
+        ) = _dispatch_dynamic_index_list(sizes)
 
-    num_loops = sum(
-        v if v == 0 else 1 for v in self.__extract_values(sizes_attr)
-    )
-    if scalable_sizes is None:
-      scalable_sizes = [False] * len(self.__extract_values(sizes_attr))
+        num_loops = sum(v if v == 0 else 1 for v in static_sizes)
 
-    if isinstance(loop_types_or_target, (Operation, Value, OpView)):
-      loop_types = [transform.AnyOpType.get()] * num_loops
-      target = loop_types_or_target
-      assert target_or_none is None, "Cannot construct TileOp with two targets."
-    else:
-      loop_types = (
-          ([loop_types_or_target] * num_loops)
-          if isinstance(loop_types_or_target, Type)
-          else loop_types_or_target
-      )
-      target = target_or_none
+        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+            loop_types = [transform.AnyOpType.get()] * num_loops
+            target = loop_types_or_target
+            assert target_or_none is None, "Cannot construct TileOp with two targets."
+        else:
+            loop_types = (
+                ([loop_types_or_target] * num_loops)
+                if isinstance(loop_types_or_target, Type)
+                else loop_types_or_target
+            )
+            target = target_or_none
 
-    target = _get_op_result_or_value(target)
+        target = _get_op_result_or_value(target)
 
-    super().__init__(
+        super().__init__(
             target.type,
             loop_types,
             target,
             dynamic_sizes=dynamic_sizes,
-            static_sizes=sizes_attr,
+            static_sizes=static_sizes,
             interchange=interchange,
             scalable_sizes=scalable_sizes,
             loc=loc,
             ip=ip,
         )
 
-  def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
-    if not attr:
-      return []
-    return [element for element in attr]
-
 
 class TileToForallOp:
     """Specialization for TileToForallOp class."""

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index acf14c197fec41..66e459917163c1 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -486,6 +486,22 @@ def testTileExplicitLoopTypeAll():
     # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
 
 
+@run
+def testTileScalable():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(
+            sequence.bodyTarget,
+            sizes=[4, [2]],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileScalable
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]]
+
+
 @run
 def testTileToForallCompact():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list