[Mlir-commits] [mlir] a7fdb90 - [mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.
Mehdi Amini
llvmlistbot at llvm.org
Mon Aug 14 09:15:35 PDT 2023
Author: Ingo Müller
Date: 2023-08-14T09:15:07-07:00
New Revision: a7fdb90bd4d8abeeef4ffbff19fe0020dd71818f
URL: https://github.com/llvm/llvm-project/commit/a7fdb90bd4d8abeeef4ffbff19fe0020dd71818f
DIFF: https://github.com/llvm/llvm-project/commit/a7fdb90bd4d8abeeef4ffbff19fe0020dd71818f.diff
LOG: [mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.
Reviewed By: springerm
Re-land 691a2fab88a0f2c763bbd26de517dcde156c5188 which was incorrectly
reverted.
Differential Revision: https://reviews.llvm.org/D157706
Added:
Modified:
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 9f623efb500173..675d4237031d98 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -187,6 +187,66 @@ def __init__(
)
+class MapCopyToThreadsOp:
+ """Specialization for MapCopyToThreadsOp class."""
+
+ @overload
+ def __init__(
+ self,
+ forall_op_type: Type,
+ tiled_op_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ forall_op_type_or_target: Union[Operation, OpView, Type, Value],
+ tiled_op_type_or_none: Optional[Type] = None,
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(forall_op_type_or_target, Type):
+ forall_op_type = forall_op_type_or_target
+ tiled_op_type = tiled_op_type_or_none
+ target = target_or_none
+ else:
+ forall_op_type = transform.AnyOpType.get()
+ tiled_op_type = transform.AnyOpType.get()
+ target = forall_op_type_or_target
+
+ super().__init__(
+ forall_op_type,
+ tiled_op_type,
+ target,
+ total_num_threads=total_num_threads,
+ desired_bit_alignment=desired_bit_alignment,
+ loc=loc,
+ ip=ip,
+ )
+
+
class MatchOp:
"""Specialization for MatchOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index cab4d2a03359cc..f4b5e2e9e948aa 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -97,6 +97,44 @@ def testInterchange():
# CHECK: iterator_interchange = [1, 0]
+ at run
+def testMapCopyToThreadsOpCompact():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MapCopyToThreadsOp(
+ sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
+ # CHECK: = transform.structured.gpu.map_copy_to_threads
+ # CHECK-SAME: total_num_threads = 32
+ # CHECK-SAME: desired_bit_alignment = 128
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+ at run
+def testMapCopyToThreadsOpTypes():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MapCopyToThreadsOp(
+ transform.OperationType.get("test.opA"),
+ transform.OperationType.get("test.opB"),
+ sequence.bodyTarget,
+ total_num_threads=32,
+ desired_bit_alignment=128,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
+ # CHECK: = transform.structured.gpu.map_copy_to_threads
+ # CHECK-SAME: total_num_threads = 32
+ # CHECK-SAME: desired_bit_alignment = 128
+ # CHECK-SAME: (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
+
+
@run
def testMatchOpNamesString():
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list