[Mlir-commits] [mlir] [mlir][python] move transform extras (PR #76102)
Maksim Levental
llvmlistbot at llvm.org
Wed Dec 20 13:20:14 PST 2023
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/76102
None
From b774189888db0ab1ee1a18698400b8eca110d469 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 14:58:07 -0600
Subject: [PATCH 1/2] [mlir][python] move transform extras to dialects
---
mlir/python/CMakeLists.txt | 2 +-
.../mlir/dialects/transform/__init__.py | 1 +
.../transform/extras}/__init__.py | 22 +++++++++----------
3 files changed, 13 insertions(+), 12 deletions(-)
rename mlir/python/mlir/{extras/dialects/transform => dialects/transform/extras}/__init__.py (87%)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..55c5973e40e525 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
- extras/dialects/transform/__init__.py)
+ dialects/transform/extras/__init__.py)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 7ae4fefbac4121..175634c7d458f1 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -6,6 +6,7 @@
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
+from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
try:
from ...ir import *
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
similarity index 87%
rename from mlir/python/mlir/extras/dialects/transform/__init__.py
rename to mlir/python/mlir/dialects/transform/extras/__init__.py
index 9e313324318aa6..8c69f12e54e36e 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -6,8 +6,8 @@
from typing import Callable, Optional, Sequence
from .... import ir
-from ....dialects import transform
-from ....dialects.transform import structured
+from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import structured
class Handle(ir.Value):
@@ -33,8 +33,8 @@ def __init__(
self.children = children if children is not None else []
-@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
-@ir.register_value_caster(transform.OperationType.get_static_typeid())
+@ir.register_value_caster(AnyOpType.get_static_typeid())
+@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
@@ -70,7 +70,7 @@ def match_ops(
if isinstance(ops, str):
ops = structured.MatchInterfaceEnum[ops]
match_op = structured.MatchOp(
- transform.AnyOpType.get(),
+ AnyOpType.get(),
self,
interface=ops,
)
@@ -78,15 +78,15 @@ def match_ops(
# Handle op name(s), either given directly as string or given as op.
else:
if isinstance(ops, str):
- op_type = transform.OperationType.get(ops)
+ op_type = OperationType.get(ops)
op_names = [ops]
elif isinstance(ops, Sequence):
- op_type = transform.AnyOpType.get()
+ op_type = AnyOpType.get()
op_names = [
op if isinstance(op, str) else op.OPERATION_NAME for op in ops
]
else:
- op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_type = OperationType.get(ops.OPERATION_NAME)
op_names = [ops.OPERATION_NAME]
match_op = structured.MatchOp.match_op_names(
op_type,
@@ -137,12 +137,12 @@ def test_match_ops_single(module: OpHandle):
with context, ir.Location.unknown(context):
with insertion_point:
- named_sequence_op = transform.NamedSequenceOp(
- "__transform_main", [transform.AnyOpType.get()], []
+ named_sequence_op = NamedSequenceOp(
+ "__transform_main", [AnyOpType.get()], []
)
with ir.InsertionPoint(named_sequence_op.body):
script(named_sequence_op.bodyTarget)
- transform.YieldOp([])
+ YieldOp([])
if dump_script:
print(named_sequence_op)
From bbf7f34d7aa7e41d1b802269a3cb8b9cb67d4c62 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 15:19:40 -0600
Subject: [PATCH 2/2] fix tests
---
mlir/test/python/dialects/transform_extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..e7b43ea63c31ca 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,7 +4,7 @@
from mlir import ir
from mlir.dialects import scf
from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
def build_transform_script(script: Callable[[OpHandle], None]):
More information about the Mlir-commits
mailing list