[Mlir-commits] [mlir] [mlir][python] move transform extras (PR #76102)

Maksim Levental llvmlistbot at llvm.org
Wed Dec 20 13:23:00 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/76102

>From b774189888db0ab1ee1a18698400b8eca110d469 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 14:58:07 -0600
Subject: [PATCH 1/3] [mlir][python] move transform extras to dialects

---
 mlir/python/CMakeLists.txt                    |  2 +-
 .../mlir/dialects/transform/__init__.py       |  1 +
 .../transform/extras}/__init__.py             | 22 +++++++++----------
 3 files changed, 13 insertions(+), 12 deletions(-)
 rename mlir/python/mlir/{extras/dialects/transform => dialects/transform/extras}/__init__.py (87%)

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..55c5973e40e525 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   GEN_ENUM_BINDINGS
   SOURCES
-    extras/dialects/transform/__init__.py)
+    dialects/transform/extras/__init__.py)
 
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 7ae4fefbac4121..175634c7d458f1 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,6 +6,7 @@
 from .._transform_ops_gen import *
 from .._transform_ops_gen import _Dialect
 from ..._mlir_libs._mlirDialectsTransform import *
+from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
 
 try:
     from ...ir import *
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
similarity index 87%
rename from mlir/python/mlir/extras/dialects/transform/__init__.py
rename to mlir/python/mlir/dialects/transform/extras/__init__.py
index 9e313324318aa6..8c69f12e54e36e 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -6,8 +6,8 @@
 from typing import Callable, Optional, Sequence
 
 from .... import ir
-from ....dialects import transform
-from ....dialects.transform import structured
+from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import structured
 
 
 class Handle(ir.Value):
@@ -33,8 +33,8 @@ def __init__(
         self.children = children if children is not None else []
 
 
-@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
-@ir.register_value_caster(transform.OperationType.get_static_typeid())
+@ir.register_value_caster(AnyOpType.get_static_typeid())
+@ir.register_value_caster(OperationType.get_static_typeid())
 class OpHandle(Handle):
     """
     Wrapper around a transform operation handle with methods to chain further
@@ -70,7 +70,7 @@ def match_ops(
             if isinstance(ops, str):
                 ops = structured.MatchInterfaceEnum[ops]
             match_op = structured.MatchOp(
-                transform.AnyOpType.get(),
+                AnyOpType.get(),
                 self,
                 interface=ops,
             )
@@ -78,15 +78,15 @@ def match_ops(
         # Handle op name(s), either given directly as string or given as op.
         else:
             if isinstance(ops, str):
-                op_type = transform.OperationType.get(ops)
+                op_type = OperationType.get(ops)
                 op_names = [ops]
             elif isinstance(ops, Sequence):
-                op_type = transform.AnyOpType.get()
+                op_type = AnyOpType.get()
                 op_names = [
                     op if isinstance(op, str) else op.OPERATION_NAME for op in ops
                 ]
             else:
-                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_type = OperationType.get(ops.OPERATION_NAME)
                 op_names = [ops.OPERATION_NAME]
             match_op = structured.MatchOp.match_op_names(
                 op_type,
@@ -137,12 +137,12 @@ def test_match_ops_single(module: OpHandle):
 
     with context, ir.Location.unknown(context):
         with insertion_point:
-            named_sequence_op = transform.NamedSequenceOp(
-                "__transform_main", [transform.AnyOpType.get()], []
+            named_sequence_op = NamedSequenceOp(
+                "__transform_main", [AnyOpType.get()], []
             )
         with ir.InsertionPoint(named_sequence_op.body):
             script(named_sequence_op.bodyTarget)
-            transform.YieldOp([])
+            YieldOp([])
 
     if dump_script:
         print(named_sequence_op)

>From bbf7f34d7aa7e41d1b802269a3cb8b9cb67d4c62 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 15:19:40 -0600
Subject: [PATCH 2/3] fix tests

---
 mlir/test/python/dialects/transform_extras.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..e7b43ea63c31ca 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,7 +4,7 @@
 from mlir import ir
 from mlir.dialects import scf
 from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
 
 
 def build_transform_script(script: Callable[[OpHandle], None]):

>From ba34bd2dbe21979d9009313c68bb230b9af99537 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 15:22:49 -0600
Subject: [PATCH 3/3] replace pipes for type hints

---
 .../dialects/transform/extras/__init__.py     | 21 ++++++++++---------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8c69f12e54e36e..c715dac1ef7eb8 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,8 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from __future__ import annotations
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence, Union
 
 from .... import ir
 from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
@@ -25,8 +24,8 @@ def __init__(
         self,
         v: ir.Value,
         *,
-        parent: Optional[Handle] = None,
-        children: Optional[Sequence[Handle]] = None,
+        parent: Optional["Handle"] = None,
+        children: Optional[Sequence["Handle"]] = None,
     ):
         super().__init__(v)
         self.parent = parent
@@ -52,11 +51,13 @@ def __init__(
 
     def match_ops(
         self,
-        ops: str
-        | ir.OpView
-        | structured.MatchInterfaceEnum
-        | Sequence[str | ir.OpView],
-    ) -> OpHandle:
+        ops: Union[
+            str,
+            ir.OpView,
+            structured.MatchInterfaceEnum,
+            Sequence[Union[str, ir.OpView]],
+        ],
+    ) -> "OpHandle":
         """
         Emits a `transform.structured.MatchOp`.
         Returns a handle to payload ops that match the given names, types, or
@@ -100,7 +101,7 @@ def match_ops(
 
 
 def insert_transform_script(
-    block_or_insertion_point: ir.Block | ir.InsertionPoint,
+    block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
     script: Callable[[OpHandle], None],
     dump_script: bool = False,
 ) -> None:



More information about the Mlir-commits mailing list