[Mlir-commits] [mlir] [mlir][python] move transform extras (PR #76102)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 20 13:31:51 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
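Moves the Python transform extras module from `mlir.extras.dialects.transform` to `mlir.dialects.transform.extras`, addressing https://github.com/llvm/llvm-project/pull/75073#discussion_r1432182914.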
---
Full diff: https://github.com/llvm/llvm-project/pull/76102.diff
4 Files Affected:
- (modified) mlir/python/CMakeLists.txt (+1-1)
- (modified) mlir/python/mlir/dialects/transform/__init__.py (+1)
- (renamed) mlir/python/mlir/dialects/transform/extras/__init__.py (+22-21)
- (modified) mlir/test/python/dialects/transform_extras.py (+1-1)
``````````diff
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..55c5973e40e525 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
- extras/dialects/transform/__init__.py)
+ dialects/transform/extras/__init__.py)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 7ae4fefbac4121..175634c7d458f1 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,6 +6,7 @@
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
+from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
try:
from ...ir import *
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
similarity index 80%
rename from mlir/python/mlir/extras/dialects/transform/__init__.py
rename to mlir/python/mlir/dialects/transform/extras/__init__.py
index 9e313324318aa6..c715dac1ef7eb8 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,12 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from __future__ import annotations
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence, Union
from .... import ir
-from ....dialects import transform
-from ....dialects.transform import structured
+from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import structured
class Handle(ir.Value):
@@ -25,16 +24,16 @@ def __init__(
self,
v: ir.Value,
*,
- parent: Optional[Handle] = None,
- children: Optional[Sequence[Handle]] = None,
+ parent: Optional["Handle"] = None,
+ children: Optional[Sequence["Handle"]] = None,
):
super().__init__(v)
self.parent = parent
self.children = children if children is not None else []
-@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
-@ir.register_value_caster(transform.OperationType.get_static_typeid())
+@ir.register_value_caster(AnyOpType.get_static_typeid())
+@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
@@ -52,11 +51,13 @@ def __init__(
def match_ops(
self,
- ops: str
- | ir.OpView
- | structured.MatchInterfaceEnum
- | Sequence[str | ir.OpView],
- ) -> OpHandle:
+ ops: Union[
+ str,
+ ir.OpView,
+ structured.MatchInterfaceEnum,
+ Sequence[Union[str, ir.OpView]],
+ ],
+ ) -> "OpHandle":
"""
Emits a `transform.structured.MatchOp`.
Returns a handle to payload ops that match the given names, types, or
@@ -70,7 +71,7 @@ def match_ops(
if isinstance(ops, str):
ops = structured.MatchInterfaceEnum[ops]
match_op = structured.MatchOp(
- transform.AnyOpType.get(),
+ AnyOpType.get(),
self,
interface=ops,
)
@@ -78,15 +79,15 @@ def match_ops(
# Handle op name(s), either given directly as string or given as op.
else:
if isinstance(ops, str):
- op_type = transform.OperationType.get(ops)
+ op_type = OperationType.get(ops)
op_names = [ops]
elif isinstance(ops, Sequence):
- op_type = transform.AnyOpType.get()
+ op_type = AnyOpType.get()
op_names = [
op if isinstance(op, str) else op.OPERATION_NAME for op in ops
]
else:
- op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_type = OperationType.get(ops.OPERATION_NAME)
op_names = [ops.OPERATION_NAME]
match_op = structured.MatchOp.match_op_names(
op_type,
@@ -100,7 +101,7 @@ def match_ops(
def insert_transform_script(
- block_or_insertion_point: ir.Block | ir.InsertionPoint,
+ block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
script: Callable[[OpHandle], None],
dump_script: bool = False,
) -> None:
@@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle):
with context, ir.Location.unknown(context):
with insertion_point:
- named_sequence_op = transform.NamedSequenceOp(
- "__transform_main", [transform.AnyOpType.get()], []
+ named_sequence_op = NamedSequenceOp(
+ "__transform_main", [AnyOpType.get()], []
)
with ir.InsertionPoint(named_sequence_op.body):
script(named_sequence_op.bodyTarget)
- transform.YieldOp([])
+ YieldOp([])
if dump_script:
print(named_sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..e7b43ea63c31ca 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,7 +4,7 @@
from mlir import ir
from mlir.dialects import scf
from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
def build_transform_script(script: Callable[[OpHandle], None]):
``````````
</details>
https://github.com/llvm/llvm-project/pull/76102
More information about the Mlir-commits mailing list