[Mlir-commits] [mlir] be6e9df - [mlir][transform][linalg][python] Add extended TileToForallOp.

Ingo Müller llvmlistbot at llvm.org
Wed Jul 19 07:02:34 PDT 2023


Author: Ingo Müller
Date: 2023-07-19T14:02:29Z
New Revision: be6e9df11f880ab128aef6550c6911d9f091e7d7

URL: https://github.com/llvm/llvm-project/commit/be6e9df11f880ab128aef6550c6911d9f091e7d7
DIFF: https://github.com/llvm/llvm-project/commit/be6e9df11f880ab128aef6550c6911d9f091e7d7.diff

LOG: [mlir][transform][linalg][python] Add extended TileToForallOp.

This patch adds a mixin for TileToForallOp to
_structured_transform_ops_ext.py with syntactic sugar for constructing
such ops. First, the types of the results are made optional and filled
with common default values if omitted. Second, for num_threads and
tile_sizes, the three possible forms (static, dynamic, or packed) can
now all be given through the same respective argument, which gets
dispatched to the correct form-specific argument automatically.

Reviewed By: nicolasvasilache, ftynse

Differential Revision: https://reviews.llvm.org/D155090

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 640730997f93e1..7f90a464741b75 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -9,7 +9,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import List, Optional, Sequence, Union, overload
+from typing import List, Optional, Sequence, Tuple, Union, overload
 
 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
@@ -17,6 +17,47 @@
 BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
 OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
 
+MixedValues = Union[  # Integer values in any of the forms below:
+    Sequence[Union[int, IntegerAttr, Operation, Value, OpView]],  # static/dynamic mix
+    ArrayAttr,  # all-static values
+    Operation,  # single packed handle
+    Value,  # single packed handle
+    OpView,  # single packed handle
+]
+
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+#   - `dynamic_values`: a list of `Value`s, potentially from op results;
+#   - `packed_values`: a value handle, potentially from an op result, associated
+#                      to one or more payload operations of integer type;
+#   - `static_values`: a `DenseI64ArrayAttr` with static values, from Python
+#                      `int`s, from `IntegerAttr`s, or from an `ArrayAttr` of
+#                      `IntegerAttr`s.
+# If the input is in the form for `packed_values`, only that result is set and
+# the other two are empty (`static_values` is `None`). Otherwise, the input can
+# be a mix of the other two forms, and for each dynamic value, the dynamic-size
+# sentinel (`ShapedType.get_dynamic_size()`) is added to `static_values`.
+# `None` is treated like an empty list.
+def _dispatch_mixed_values(
+    values: Optional[MixedValues],
+) -> Tuple[
+    List[Value], Optional[Union[Operation, Value, OpView]], Optional[DenseI64ArrayAttr]
+]:
+    dynamic_values = []
+    packed_values = None
+    static_values = None
+    if isinstance(values, ArrayAttr):
+        # The return type promises a `DenseI64ArrayAttr`, so unpack the integer
+        # elements of the `ArrayAttr` instead of returning it unchanged.
+        static_values = DenseI64ArrayAttr.get(
+            [IntegerAttr(element).value for element in values]
+        )
+    elif isinstance(values, (Operation, Value, OpView)):
+        packed_values = values
+    else:
+        static_sizes = []
+        for size in values or []:
+            if isinstance(size, int):
+                static_sizes.append(size)
+            elif isinstance(size, IntegerAttr):
+                # A static value wrapped in an attribute: unwrap it rather than
+                # treating it as a dynamic handle.
+                static_sizes.append(size.value)
+            else:
+                static_sizes.append(ShapedType.get_dynamic_size())
+                dynamic_values.append(_get_op_result_or_value(size))
+        static_values = DenseI64ArrayAttr.get(static_sizes)
+
+    return (dynamic_values, packed_values, static_values)
+
 
 def _get_int_int_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
@@ -354,6 +395,98 @@ def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
     return [element for element in attr]
 
 
+class TileToForallOp:
+    """Specialization for TileToForallOp class.
+
+    Adds syntactic sugar on top of the generated builder:
+      - The result types `loops_type` and `tiled_op_type` are optional; when
+        omitted, both default to `transform.AnyOpType`.
+      - `num_threads` and `tile_sizes` each accept any `MixedValues` form (a
+        list mixing static ints/`IntegerAttr`s and dynamic handles, an
+        `ArrayAttr`, or a single packed handle). The value is dispatched to
+        the matching dynamic/packed/static argument of the generated builder.
+    """
+
+    @overload
+    def __init__(
+        self,
+        loops_type: Type,
+        tiled_op_type: Type,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        # `loops_type` in the first overload, `target` in the second one.
+        loops_type_or_target: Union[Type, Operation, Value, OpView],
+        tiled_op_type_or_none: Optional[Type] = None,
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        # `Type` arguments in the front are optional: add default values to front.
+        if isinstance(loops_type_or_target, Type):
+            # First overload: type arguments provided.
+            if not isinstance(tiled_op_type_or_none, Type):
+                raise TypeError(
+                    "If 'loops_type_or_target' is a type, then "
+                    "'tiled_op_type_or_none' is expected to be one as well."
+                )
+            if target_or_none is None:
+                # Fail early with a clear message instead of inside the
+                # generated builder.
+                raise TypeError(
+                    "If 'loops_type_or_target' is a type, then "
+                    "'target_or_none' is expected to be the target handle."
+                )
+            loops_type = loops_type_or_target
+            tiled_op_type = tiled_op_type_or_none
+            target = target_or_none
+        else:
+            # Second overload: type arguments missing, so the first positional
+            # argument is the target. Reject extra positional arguments, which
+            # would otherwise be silently ignored.
+            if tiled_op_type_or_none is not None or target_or_none is not None:
+                raise TypeError(
+                    "If 'loops_type_or_target' is not a type, it is the target "
+                    "and no further positional arguments are accepted."
+                )
+            loops_type = transform.AnyOpType.get()
+            tiled_op_type = transform.AnyOpType.get()
+            target = loops_type_or_target
+
+        # Unpack mixed num_threads.
+        (
+            dynamic_num_threads,
+            packed_num_threads,
+            num_threads_attr,
+        ) = _dispatch_mixed_values(num_threads)
+
+        # Unpack mixed tile_sizes.
+        (
+            dynamic_tile_sizes,
+            packed_tile_sizes,
+            tile_sizes_attr,
+        ) = _dispatch_mixed_values(tile_sizes)
+
+        super().__init__(
+            loops_type,
+            tiled_op_type,
+            target=target,
+            tile_sizes=dynamic_tile_sizes,
+            packed_tile_sizes=packed_tile_sizes,
+            static_tile_sizes=tile_sizes_attr,
+            num_threads=dynamic_num_threads,
+            packed_num_threads=packed_num_threads,
+            static_num_threads=num_threads_attr,
+            mapping=mapping,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class VectorizeOp:
     """Specialization for VectorizeOp class."""
 

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 03a47166c824c6..1663ea3b7ae733 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -255,6 +255,98 @@ def testTileExplicitLoopTypeAll():
     # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
 
 
+@run
+def testTileToForallCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("linalg.matmul"),
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallCompact
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testTileToForallLoopsAndTileOpTypes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(
+            transform.OperationType.get("scf.forall"),  # loops_type
+            transform.OperationType.get("linalg.matmul"),  # tiled_op_type
+            sequence.bodyTarget,
+            num_threads=[2, 3, 4],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)
+
+
+@run
+def testTileToForallTileSizes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallTileSizes
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4]
+
+
+@run
+def testTileToForallMixedDynamic():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=[n, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallMixedDynamic
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [%{{.*}} : !pdl.operation, 3, 4]
+
+
+@run
+def testTileToForallPackedDynamic():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=n)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallPackedDynamic
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads *(%{{.*}} : !pdl.operation)
+
+
+@run
+def testTileToForallMapping():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
+        structured.TileToForallOp(
+            sequence.bodyTarget, num_threads=[2, 3], mapping=mapping
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallMapping
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
+
+
 @run
 def testVectorize():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list