[Mlir-commits] [mlir] 579ced4 - [MLIR][Python] Add structured.fuseop to python interpreter (#120601)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 3 03:22:02 PST 2025
Author: Hugo Trachino
Date: 2025-01-03T11:21:59Z
New Revision: 579ced4f8266b273d15b2801067a828151a222ef
URL: https://github.com/llvm/llvm-project/commit/579ced4f8266b273d15b2801067a828151a222ef
DIFF: https://github.com/llvm/llvm-project/commit/579ced4f8266b273d15b2801067a828151a222ef.diff
LOG: [MLIR][Python] Add structured.fuseop to python interpreter (#120601)
Implements a Python interface for `structured.FuseOp` that allows more flexible inputs.
Added:
Modified:
mlir/python/mlir/dialects/transform/structured.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 9121aa8e40237b..bf40cc532065db 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -140,6 +140,77 @@ def __init__(
)
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseOp(FuseOp):
+    """Specialization for FuseOp class.
+
+    The loop result types may be omitted, in which case every loop handle is
+    typed `!transform.any_op`. They may also be given as a single `Type`,
+    which is repeated once per non-zero tile size.
+    """
+
+    @overload
+    def __init__(
+        self,
+        loop_types: Union[Type, Sequence[Type]],
+        target: Union[Operation, Value, OpView],
+        *,
+        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        tile_interchange: OptionalIntList = None,
+        apply_cleanup: Optional[bool] = False,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        tile_interchange: OptionalIntList = None,
+        apply_cleanup: Optional[bool] = False,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        tile_interchange: OptionalIntList = None,
+        apply_cleanup: Optional[bool] = False,
+        loc=None,
+        ip=None,
+    ):
+        """Builds a `transform.structured.fuse` op.
+
+        Args:
+          loop_types_or_target: Either the handle to fuse (every loop handle
+            result is then typed `!transform.any_op`), or the loop result
+            type(s). A single `Type` is repeated once per non-zero tile size;
+            a sequence is used as-is.
+          target_or_none: The handle to fuse when loop types are passed first;
+            must be omitted when the target is passed first.
+          tile_sizes: Static tile sizes. Each non-zero entry yields one loop
+            handle result.
+          tile_interchange: Static interchange values.
+          apply_cleanup: Forwarded unchanged to the generated op builder.
+          loc: Optional location for the op.
+          ip: Optional insertion point for the op.
+
+        Raises:
+          TypeError: if the target is given twice or not given at all.
+          ValueError: if `tile_sizes` or `tile_interchange` contain dynamic
+            (SSA) or scalable entries. Only static values are forwarded to the
+            op, so such entries would otherwise be silently dropped.
+        """
+        # Normalize `None` / empty inputs to empty lists before dispatching.
+        tile_sizes = tile_sizes if tile_sizes else []
+        tile_interchange = tile_interchange if tile_interchange else []
+        dynamic_sizes, tile_sizes, scalable_sizes = _dispatch_dynamic_index_list(
+            tile_sizes
+        )
+        dynamic_interchange, tile_interchange, scalable_interchange = (
+            _dispatch_dynamic_index_list(tile_interchange)
+        )
+        # Only the static lists reach the op builder below. Reject anything
+        # that would be lost instead of building a silently different op.
+        if dynamic_sizes or any(scalable_sizes):
+            raise ValueError("FuseOp only supports static, non-scalable tile sizes.")
+        if dynamic_interchange or any(scalable_interchange):
+            raise ValueError("FuseOp only supports static tile interchange values.")
+        # Each non-zero tile size produces one loop handle result.
+        num_loops = sum(1 for size in tile_sizes if size != 0)
+
+        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+            if target_or_none is not None:
+                raise TypeError("Cannot construct FuseOp with two targets.")
+            loop_types = [transform.AnyOpType.get()] * num_loops
+            target = loop_types_or_target
+        else:
+            if target_or_none is None:
+                raise TypeError("Cannot construct FuseOp without a target.")
+            loop_types = (
+                [loop_types_or_target] * num_loops
+                if isinstance(loop_types_or_target, Type)
+                else loop_types_or_target
+            )
+            target = target_or_none
+        super().__init__(
+            target.type,
+            loop_types,
+            target,
+            tile_sizes=tile_sizes,
+            tile_interchange=tile_interchange,
+            apply_cleanup=apply_cleanup,
+            loc=loc,
+            ip=ip,
+        )
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
"""Specialization for GeneralizeOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index fb4c75b5337928..8785d6d3600746 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -101,6 +101,42 @@ def testFuseIntoContainingOpCompact(target):
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+@run
+@create_sequence
+def testFuseOpCompact(target):
+    """Fuse op with tile sizes, interchange and cleanup all set."""
+    structured.FuseOp(
+        target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True
+    )
+    # CHECK-LABEL: TEST: testFuseOpCompact
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+    # CHECK-SAME: interchange [0, 1] apply_cleanup = true
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpNoArg(target):
+    """Without tile sizes no loop handles are produced, only the fused one."""
+    structured.FuseOp(target)
+    # CHECK-LABEL: TEST: testFuseOpNoArg
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+@create_sequence
+def testFuseOpAttributes(target):
+    """Tile sizes and interchange may be passed as DenseI64ArrayAttr."""
+    attr = DenseI64ArrayAttr.get([4, 8])
+    ichange = DenseI64ArrayAttr.get([0, 1])
+    structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
+    # CHECK-LABEL: TEST: testFuseOpAttributes
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+    # CHECK-SAME: interchange [0, 1]
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpExplicitLoopTypes(target):
+    """A single loop type is repeated once per non-zero tile size."""
+    structured.FuseOp(
+        transform.OperationType.get("scf.for"), target, tile_sizes=[4, 0, 8]
+    )
+    # CHECK-LABEL: TEST: testFuseOpExplicitLoopTypes
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 0, 8]
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">)
+
+
@run
@create_sequence
def testGeneralize(target):
More information about the Mlir-commits mailing list