[Mlir-commits] [mlir] [mlir][python] move transform extras (PR #76102)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 20 13:31:51 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

Addresses https://github.com/llvm/llvm-project/pull/75073#discussion_r1432182914.

---
Full diff: https://github.com/llvm/llvm-project/pull/76102.diff


4 Files Affected:

- (modified) mlir/python/CMakeLists.txt (+1-1) 
- (modified) mlir/python/mlir/dialects/transform/__init__.py (+1) 
- (renamed) mlir/python/mlir/dialects/transform/extras/__init__.py (+22-21) 
- (modified) mlir/test/python/dialects/transform_extras.py (+1-1) 


``````````diff
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..55c5973e40e525 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   GEN_ENUM_BINDINGS
   SOURCES
-    extras/dialects/transform/__init__.py)
+    dialects/transform/extras/__init__.py)
 
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 7ae4fefbac4121..175634c7d458f1 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,6 +6,7 @@
 from .._transform_ops_gen import *
 from .._transform_ops_gen import _Dialect
 from ..._mlir_libs._mlirDialectsTransform import *
+from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
 
 try:
     from ...ir import *
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
similarity index 80%
rename from mlir/python/mlir/extras/dialects/transform/__init__.py
rename to mlir/python/mlir/dialects/transform/extras/__init__.py
index 9e313324318aa6..c715dac1ef7eb8 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,12 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from __future__ import annotations
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence, Union
 
 from .... import ir
-from ....dialects import transform
-from ....dialects.transform import structured
+from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import structured
 
 
 class Handle(ir.Value):
@@ -25,16 +24,16 @@ def __init__(
         self,
         v: ir.Value,
         *,
-        parent: Optional[Handle] = None,
-        children: Optional[Sequence[Handle]] = None,
+        parent: Optional["Handle"] = None,
+        children: Optional[Sequence["Handle"]] = None,
     ):
         super().__init__(v)
         self.parent = parent
         self.children = children if children is not None else []
 
 
-@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
-@ir.register_value_caster(transform.OperationType.get_static_typeid())
+@ir.register_value_caster(AnyOpType.get_static_typeid())
+@ir.register_value_caster(OperationType.get_static_typeid())
 class OpHandle(Handle):
     """
     Wrapper around a transform operation handle with methods to chain further
@@ -52,11 +51,13 @@ def __init__(
 
     def match_ops(
         self,
-        ops: str
-        | ir.OpView
-        | structured.MatchInterfaceEnum
-        | Sequence[str | ir.OpView],
-    ) -> OpHandle:
+        ops: Union[
+            str,
+            ir.OpView,
+            structured.MatchInterfaceEnum,
+            Sequence[Union[str, ir.OpView]],
+        ],
+    ) -> "OpHandle":
         """
         Emits a `transform.structured.MatchOp`.
         Returns a handle to payload ops that match the given names, types, or
@@ -70,7 +71,7 @@ def match_ops(
             if isinstance(ops, str):
                 ops = structured.MatchInterfaceEnum[ops]
             match_op = structured.MatchOp(
-                transform.AnyOpType.get(),
+                AnyOpType.get(),
                 self,
                 interface=ops,
             )
@@ -78,15 +79,15 @@ def match_ops(
         # Handle op name(s), either given directly as string or given as op.
         else:
             if isinstance(ops, str):
-                op_type = transform.OperationType.get(ops)
+                op_type = OperationType.get(ops)
                 op_names = [ops]
             elif isinstance(ops, Sequence):
-                op_type = transform.AnyOpType.get()
+                op_type = AnyOpType.get()
                 op_names = [
                     op if isinstance(op, str) else op.OPERATION_NAME for op in ops
                 ]
             else:
-                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_type = OperationType.get(ops.OPERATION_NAME)
                 op_names = [ops.OPERATION_NAME]
             match_op = structured.MatchOp.match_op_names(
                 op_type,
@@ -100,7 +101,7 @@ def match_ops(
 
 
 def insert_transform_script(
-    block_or_insertion_point: ir.Block | ir.InsertionPoint,
+    block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
     script: Callable[[OpHandle], None],
     dump_script: bool = False,
 ) -> None:
@@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle):
 
     with context, ir.Location.unknown(context):
         with insertion_point:
-            named_sequence_op = transform.NamedSequenceOp(
-                "__transform_main", [transform.AnyOpType.get()], []
+            named_sequence_op = NamedSequenceOp(
+                "__transform_main", [AnyOpType.get()], []
             )
         with ir.InsertionPoint(named_sequence_op.body):
             script(named_sequence_op.bodyTarget)
-            transform.YieldOp([])
+            YieldOp([])
 
     if dump_script:
         print(named_sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..e7b43ea63c31ca 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,7 +4,7 @@
 from mlir import ir
 from mlir.dialects import scf
 from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
 
 
 def build_transform_script(script: Callable[[OpHandle], None]):

``````````

</details>


https://github.com/llvm/llvm-project/pull/76102


More information about the Mlir-commits mailing list