[Mlir-commits] [mlir] a7fdb90 - [mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.

Mehdi Amini llvmlistbot at llvm.org
Mon Aug 14 09:15:35 PDT 2023


Author: Ingo Müller
Date: 2023-08-14T09:15:07-07:00
New Revision: a7fdb90bd4d8abeeef4ffbff19fe0020dd71818f

URL: https://github.com/llvm/llvm-project/commit/a7fdb90bd4d8abeeef4ffbff19fe0020dd71818f
DIFF: https://github.com/llvm/llvm-project/commit/a7fdb90bd4d8abeeef4ffbff19fe0020dd71818f.diff

LOG: [mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.

Reviewed By: springerm

Re-land 691a2fab88a0f2c763bbd26de517dcde156c5188 which was incorrectly
reverted.

Differential Revision: https://reviews.llvm.org/D157706

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 9f623efb500173..675d4237031d98 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -187,6 +187,66 @@ def __init__(
         )
 
 
+class MapCopyToThreadsOp:
+    """Specialization for MapCopyToThreadsOp; result types default to AnyOpType."""
+
+    @overload  # typed form: explicit forall/tiled result types, then the target
+    def __init__(
+        self,
+        forall_op_type: Type,
+        tiled_op_type: Type,
+        target: Union[Operation, OpView, Value],
+        *,
+        total_num_threads: Union[int, IntegerAttr],
+        desired_bit_alignment: Union[int, IntegerAttr],
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload  # compact form: target only; both result types become AnyOpType
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        *,
+        total_num_threads: Union[int, IntegerAttr],
+        desired_bit_alignment: Union[int, IntegerAttr],
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(  # runtime implementation backing both overloads
+        self,
+        forall_op_type_or_target: Union[Operation, OpView, Type, Value],  # a Type selects the typed form
+        tiled_op_type_or_none: Optional[Type] = None,  # only used in the typed form
+        target_or_none: Optional[Union[Operation, OpView, Value]] = None,  # only used in the typed form
+        *,
+        total_num_threads: Union[int, IntegerAttr],
+        desired_bit_alignment: Union[int, IntegerAttr],
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(forall_op_type_or_target, Type):  # typed form: positional args are already in order
+            forall_op_type = forall_op_type_or_target
+            tiled_op_type = tiled_op_type_or_none
+            target = target_or_none  # NOTE(review): tiled type/target not validated; omitted ones reach super() as None
+        else:  # compact form: first positional argument is the target
+            forall_op_type = transform.AnyOpType.get()
+            tiled_op_type = transform.AnyOpType.get()
+            target = forall_op_type_or_target
+
+        super().__init__(  # NOTE(review): presumably the generated op builder; confirm its argument order
+            forall_op_type,
+            tiled_op_type,
+            target,
+            total_num_threads=total_num_threads,
+            desired_bit_alignment=desired_bit_alignment,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MatchOp:
     """Specialization for MatchOp class."""
 

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index cab4d2a03359cc..f4b5e2e9e948aa 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -97,6 +97,44 @@ def testInterchange():
     # CHECK: iterator_interchange = [1, 0]
 
 
+@run
+def testMapCopyToThreadsOpCompact():  # compact form: target only, default result types
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MapCopyToThreadsOp(
+            sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
+    # CHECK: = transform.structured.gpu.map_copy_to_threads
+    # CHECK-SAME: total_num_threads = 32
+    # CHECK-SAME: desired_bit_alignment = 128
+    # CHECK-SAME:  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testMapCopyToThreadsOpTypes():  # typed form: explicit forall/tiled result types
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MapCopyToThreadsOp(
+            transform.OperationType.get("test.opA"),
+            transform.OperationType.get("test.opB"),
+            sequence.bodyTarget,
+            total_num_threads=32,
+            desired_bit_alignment=128,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
+    # CHECK: = transform.structured.gpu.map_copy_to_threads
+    # CHECK-SAME: total_num_threads = 32
+    # CHECK-SAME: desired_bit_alignment = 128
+    # CHECK-SAME:  (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
+
+
 @run
 def testMatchOpNamesString():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list