[Mlir-commits] [mlir] [mlir][python] meta region_op (PR #75673)
Maksim Levental
llvmlistbot at llvm.org
Wed Dec 20 13:47:22 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/75673
From b774189888db0ab1ee1a18698400b8eca110d469 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 14:58:07 -0600
Subject: [PATCH 1/4] [mlir][python] move transform extras to dialects
---
mlir/python/CMakeLists.txt | 2 +-
.../mlir/dialects/transform/__init__.py | 1 +
.../transform/extras}/__init__.py | 22 +++++++++----------
3 files changed, 13 insertions(+), 12 deletions(-)
rename mlir/python/mlir/{extras/dialects/transform => dialects/transform/extras}/__init__.py (87%)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..55c5973e40e525 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
- extras/dialects/transform/__init__.py)
+ dialects/transform/extras/__init__.py)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 7ae4fefbac4121..175634c7d458f1 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,6 +6,7 @@
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
+from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
try:
from ...ir import *
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
similarity index 87%
rename from mlir/python/mlir/extras/dialects/transform/__init__.py
rename to mlir/python/mlir/dialects/transform/extras/__init__.py
index 9e313324318aa6..8c69f12e54e36e 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -6,8 +6,8 @@
from typing import Callable, Optional, Sequence
from .... import ir
-from ....dialects import transform
-from ....dialects.transform import structured
+from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import structured
class Handle(ir.Value):
@@ -33,8 +33,8 @@ def __init__(
self.children = children if children is not None else []
-@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
-@ir.register_value_caster(transform.OperationType.get_static_typeid())
+@ir.register_value_caster(AnyOpType.get_static_typeid())
+@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
@@ -70,7 +70,7 @@ def match_ops(
if isinstance(ops, str):
ops = structured.MatchInterfaceEnum[ops]
match_op = structured.MatchOp(
- transform.AnyOpType.get(),
+ AnyOpType.get(),
self,
interface=ops,
)
@@ -78,15 +78,15 @@ def match_ops(
# Handle op name(s), either given directly as string or given as op.
else:
if isinstance(ops, str):
- op_type = transform.OperationType.get(ops)
+ op_type = OperationType.get(ops)
op_names = [ops]
elif isinstance(ops, Sequence):
- op_type = transform.AnyOpType.get()
+ op_type = AnyOpType.get()
op_names = [
op if isinstance(op, str) else op.OPERATION_NAME for op in ops
]
else:
- op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_type = OperationType.get(ops.OPERATION_NAME)
op_names = [ops.OPERATION_NAME]
match_op = structured.MatchOp.match_op_names(
op_type,
@@ -137,12 +137,12 @@ def test_match_ops_single(module: OpHandle):
with context, ir.Location.unknown(context):
with insertion_point:
- named_sequence_op = transform.NamedSequenceOp(
- "__transform_main", [transform.AnyOpType.get()], []
+ named_sequence_op = NamedSequenceOp(
+ "__transform_main", [AnyOpType.get()], []
)
with ir.InsertionPoint(named_sequence_op.body):
script(named_sequence_op.bodyTarget)
- transform.YieldOp([])
+ YieldOp([])
if dump_script:
print(named_sequence_op)
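A minimal usage sketch of the relocated module (not part of the patch; assumes an in-tree build of the MLIR Python bindings). Because of the value casters registered above, the block argument of the generated named sequence reaches the script already wrapped as an OpHandle:

from mlir import ir
from mlir.dialects.transform.extras import OpHandle, insert_transform_script

def match_matmul(op: OpHandle):
    # Emits transform.structured.match ops{["linalg.matmul"]} on the root handle.
    op.match_ops("linalg.matmul")

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    # Builds transform.named_sequence @__transform_main around the script and prints it.
    insert_transform_script(module.body, script=match_matmul, dump_script=True)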
From bbf7f34d7aa7e41d1b802269a3cb8b9cb67d4c62 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 15:19:40 -0600
Subject: [PATCH 2/4] fix tests
---
mlir/test/python/dialects/transform_extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..e7b43ea63c31ca 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,7 +4,7 @@
from mlir import ir
from mlir.dialects import scf
from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
def build_transform_script(script: Callable[[OpHandle], None]):
From ba34bd2dbe21979d9009313c68bb230b9af99537 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 15:22:49 -0600
Subject: [PATCH 3/4] replace pipes for type hints
---
.../dialects/transform/extras/__init__.py | 21 ++++++++++---------
1 file changed, 11 insertions(+), 10 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8c69f12e54e36e..c715dac1ef7eb8 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,8 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from __future__ import annotations
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence, Union
from .... import ir
from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
@@ -25,8 +24,8 @@ def __init__(
self,
v: ir.Value,
*,
- parent: Optional[Handle] = None,
- children: Optional[Sequence[Handle]] = None,
+ parent: Optional["Handle"] = None,
+ children: Optional[Sequence["Handle"]] = None,
):
super().__init__(v)
self.parent = parent
@@ -52,11 +51,13 @@ def __init__(
def match_ops(
self,
- ops: str
- | ir.OpView
- | structured.MatchInterfaceEnum
- | Sequence[str | ir.OpView],
- ) -> OpHandle:
+ ops: Union[
+ str,
+ ir.OpView,
+ structured.MatchInterfaceEnum,
+ Sequence[Union[str, ir.OpView]],
+ ],
+ ) -> "OpHandle":
"""
Emits a `transform.structured.MatchOp`.
Returns a handle to payload ops that match the given names, types, or
@@ -100,7 +101,7 @@ def match_ops(
def insert_transform_script(
- block_or_insertion_point: ir.Block | ir.InsertionPoint,
+ block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
script: Callable[[OpHandle], None],
dump_script: bool = False,
) -> None:
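Presumably the motivation is compatibility: with `from __future__ import annotations` removed, the hints in this file are evaluated eagerly, and the `X | Y` union syntax only evaluates at runtime on Python 3.10+, so `typing.Union` keeps the module importable on older Python versions. A tiny illustration (hypothetical, not part of the patch):

import inspect
from typing import Union

def takes_one(x: Union[str, int]):  # evaluates fine at runtime on Python 3.8+
    ...

# A real typing.Union object; with deferred (string) annotations this would
# instead be the string "Union[str, int]".
print(inspect.signature(takes_one).parameters["x"].annotation)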
From 49e75c59c8b6a12a1791253c9a1e207d230ae63b Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 15 Dec 2023 18:21:31 -0600
Subject: [PATCH 4/4] [mlir][python] meta region_op
---
mlir/python/CMakeLists.txt | 9 +-
mlir/python/mlir/dialects/func.py | 3 +
mlir/python/mlir/dialects/pdl.py | 4 +
.../mlir/dialects/transform/__init__.py | 8 +-
.../dialects/transform/extras/__init__.py | 14 ++-
mlir/python/mlir/extras/meta.py | 59 ++++++++++
mlir/test/python/dialects/transform_extras.py | 101 +++++++++++++++++-
7 files changed, 191 insertions(+), 7 deletions(-)
create mode 100644 mlir/python/mlir/extras/meta.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55c5973e40e525..3c9cf304d88a27 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
- extras/types.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/_mlir/passmanager.pyi
)
+declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ ADD_TO_PARENT MLIRPythonSources.Core.Python
+ SOURCES
+ extras/types.py
+ extras/meta.py
+)
+
declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
ADD_TO_PARENT MLIRPythonSources
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 6599f67b707877..24fdcbcd85b29f 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -243,6 +243,9 @@ def emit_call_op(*call_args):
return decorator
+func = FuncOp.from_py_func
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
"""Specialization for the call op class."""
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index 90d7d706238e64..de239f23d9fa96 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -220,3 +220,7 @@ def __init__(
constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
+
+
+def op_t():
+ return OperationType.get()
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 175634c7d458f1..435c1668d0d70a 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -175,7 +175,7 @@ def __init__(
result_types: Sequence[Type],
sym_visibility=None,
arg_attrs=None,
- res_attrs=None
+ res_attrs=None,
):
function_type = FunctionType.get(input_types, result_types)
super().__init__(
@@ -183,7 +183,7 @@ def __init__(
function_type=TypeAttr.get(function_type),
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
- res_attrs=res_attrs
+ res_attrs=res_attrs,
)
self.regions[0].blocks.append(*input_types)
@@ -212,3 +212,7 @@ def __init__(
if operands is None:
operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+
+
+def any_op_t():
+ return AnyOpType.get()
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index c715dac1ef7eb8..1e8c0652389226 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -4,8 +4,16 @@
from typing import Callable, Optional, Sequence, Union
+from ....extras.meta import region_op
from .... import ir
-from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import (
+ AnyOpType,
+ OperationType,
+ NamedSequenceOp,
+ YieldOp,
+ SequenceOp,
+ ApplyPatternsOp,
+)
from .. import structured
@@ -147,3 +155,7 @@ def test_match_ops_single(module: OpHandle):
if dump_script:
print(named_sequence_op)
+
+
+sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
+apply_patterns = region_op(ApplyPatternsOp)
diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
new file mode 100644
index 00000000000000..dce61d80eeea60
--- /dev/null
+++ b/mlir/python/mlir/extras/meta.py
@@ -0,0 +1,59 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import inspect
+from functools import wraps
+
+from ..dialects._ods_common import get_op_result_or_op_results
+from ..ir import Type, InsertionPoint
+
+
+def op_region_builder(op, op_region, terminator=None):
+ def builder_wrapper(body_builder):
+ # add a block with block args having types ...
+ if len(op_region.blocks) == 0:
+ sig = inspect.signature(body_builder)
+ types = [p.annotation for p in sig.parameters.values()]
+ if not (
+ len(types) == len(sig.parameters)
+ and all(isinstance(t, Type) for t in types)
+ ):
+ raise ValueError(
+ f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
+ )
+
+ op_region.blocks.append(*types)
+
+ with InsertionPoint(op_region.blocks[0]):
+ results = body_builder(*list(op_region.blocks[0].arguments))
+
+ with InsertionPoint(list(op_region.blocks)[-1]):
+ if terminator is not None:
+ res = []
+ if isinstance(results, (tuple, list)):
+ res.extend(results)
+ elif results is not None:
+ res.append(results)
+ terminator(res)
+
+ return get_op_result_or_op_results(op)
+
+ return builder_wrapper
+
+
+def region_op(op_constructor, terminator=None):
+ def op_decorator(*args, **kwargs):
+ op = op_constructor(*args, **kwargs)
+ op_region = op.regions[0]
+
+ return op_region_builder(op, op_region, terminator)
+
+ @wraps(op_decorator)
+ def maybe_no_args(*args, **kwargs):
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+ return op_decorator()(args[0])
+ else:
+ return op_decorator(*args, **kwargs)
+
+ return maybe_no_args
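To make the resulting decorator pattern concrete, here is a condensed sketch of a region_op-wrapped constructor in use (not part of the patch; the full FileCheck-verified versions are in the test updates below):

from mlir import ir
from mlir.dialects.transform import FailurePropagationMode, any_op_t
from mlir.dialects.transform.extras import sequence

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        # sequence = region_op(SequenceOp.__base__, terminator=YieldOp): the
        # decorator constructs the op, appends a block whose arguments come
        # from the parameter annotations (which must be ir.Type instances),
        # runs the body at an insertion point inside that block, and then
        # inserts the YieldOp terminator.
        @sequence([], FailurePropagationMode.Propagate, [])
        def basic(target: any_op_t()):
            pass  # transform ops created here land inside the sequence region

    print(module)  # transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): ... }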
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index e7b43ea63c31ca..2345e5b375c6df 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -2,9 +2,34 @@
from typing import Callable
from mlir import ir
-from mlir.dialects import scf
-from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+from mlir.dialects import scf, pdl, func, arith, linalg
+from mlir.dialects.transform import (
+ structured,
+ get_parent_op,
+ apply_patterns_canonicalization,
+ apply_cse,
+ any_op_t,
+)
+from mlir.dialects.transform import FailurePropagationMode
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.dialects.transform.extras import (
+ OpHandle,
+ insert_transform_script,
+ sequence,
+ apply_patterns,
+)
+from mlir.extras import types as T
+
+
+def construct_and_print_in_module(f):
+ print("\nTEST:", f.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
def build_transform_script(script: Callable[[OpHandle], None]):
@@ -93,3 +118,73 @@ def test_match_ops_mixed(op: OpHandle):
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
# CHECK-SAME: -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_sequence_region
+@construct_and_print_in_module
+def test_sequence_region():
+ # CHECK-LABEL: func.func @loop_unroll_op() {
+ # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+ # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
+ # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.func()
+ def loop_unroll_op():
+ for i in scf.for_(0, 42, 5):
+ v = arith.addi(i, i)
+ scf.yield_([])
+
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+ # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+ # CHECK: }
+ @sequence([], FailurePropagationMode.Propagate, [])
+ def basic(target: any_op_t()):
+ m = structured_match(any_op_t(), target, ops=["arith.addi"])
+ loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
+ loop_unroll(loop, 4)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+@construct_and_print_in_module
+def test_apply_patterns():
+ M, N, K = 3, 5, 3
+
+ # CHECK-LABEL: func.func @matmul(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ # CHECK: %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+ # CHECK: return %[[VAL_3]] : tensor<3x3xf32>
+ # CHECK: }
+ @func.func(
+ T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
+ )
+ def matmul(A, B, C):
+ return linalg.matmul(A, B, outs=[C])
+
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+ # CHECK: apply_patterns to %[[VAL_2]] {
+ # CHECK: transform.apply_patterns.canonicalization
+ # CHECK: } : !pdl.operation
+ # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
+ # CHECK: }
+ @sequence([], FailurePropagationMode.Propagate, [])
+ def basic(variant_op: any_op_t()):
+ matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+ top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
+
+ @apply_patterns(top_func)
+ def pats():
+ apply_patterns_canonicalization()
+
+ top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
+ apply_cse(top_func)
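The `func = FuncOp.from_py_func` alias and the extras type helpers compose the same way outside the transform tests; a short sketch (not part of the patch, reusing only helpers exercised above):

from mlir import ir
from mlir.dialects import arith, func
from mlir.extras import types as T

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        # func.func aliases FuncOp.from_py_func: argument types are passed to
        # the decorator, the Python body builds the region, and the returned
        # value becomes the func.return operand (the result type is inferred).
        @func.func(T.f32(), T.f32())
        def add(a, b):
            return arith.addf(a, b)

    print(module)  # func.func @add(%arg0: f32, %arg1: f32) -> f32 { ... }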