[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 13 06:42:51 PST 2023


Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/75073 at github.com>


https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/75073

>From 5ec81e059c9c1a2be91782beca501a5e0d5df474 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Wed, 6 Dec 2023 16:32:56 +0100
Subject: [PATCH 1/5] [MLIR][transform][python] add extras for sugared python
 abstractions

---
 mlir/python/CMakeLists.txt                    |   8 ++
 .../dialects/transform/extras/__init__.py     | 113 ++++++++++++++++++
 mlir/test/python/dialects/transform_extras.py |  80 +++++++++++++
 3 files changed, 201 insertions(+)
 create mode 100644 mlir/python/mlir/dialects/transform/extras/__init__.py
 create mode 100644 mlir/test/python/dialects/transform_extras.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc26335..8013b49dbf9d68 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
     dialects/rocdl.py
   DIALECT_NAME rocdl)
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.extras
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  GEN_ENUM_BINDINGS
+  SOURCES
+    dialects/transform/extras/__init__.py)
+
 declare_mlir_python_sources(
   MLIRPythonSources.Dialects.quant
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 00000000000000..0ccbb2e9254fd2
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence, TypeVar
+
+try:
+    from .... import ir
+    from ....dialects import transform
+    from ....dialects.transform import structured
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+
+@dataclass
+class Value(abc.ABC):
+    """Wrapper around a transform value handle with methods to chain further transforms."""
+
+    _mlir_value: ir.Value
+    children: list[Value] = field(default_factory=list)
+    parent: Optional[Value] = None
+
+    @property
+    def mlir_value(self) -> ir.Value:
+        return self._mlir_value
+
+
+@dataclass
+class Param(Value):
+    """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+@dataclass
+class OpHandle(Value):
+    """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+    def match_ops(
+        self,
+        ops: str
+        | ir.OpView
+        | structured.MatchInterfaceEnum
+        | Sequence[str | ir.OpView],
+    ) -> OpHandle:
+        """Returns a handle to ops that match the given names, types, or interface.
+
+        If only a single type is given, the value wrapped by the resulting
+        handle is populated with the respective type.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self.mlir_value,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self.mlir_value,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def insert_transform_script(
+    module: ir.Module,
+    script: Callable[[ValueT], None],
+    dump_script: bool = False,
+) -> None:
+    """Inserts the transform script of the schedule into the module.
+
+    Args:
+        module: Existing module into which the script should be inserted.
+        script: The transform script to apply at.
+        dump_script: Whether to dump the script after creation.
+    """
+    # Insert the script into the IR
+    with module.context, ir.Location.unknown(module.context):
+        with ir.InsertionPoint.at_block_begin(module.body):
+            sequence_op = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                (),
+                transform.AnyOpType.get(),
+            )
+        with ir.InsertionPoint(sequence_op.body):
+            script(OpHandle(sequence_op.bodyTarget))
+            transform.YieldOp([])
+
+    if dump_script:
+        print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 00000000000000..8c482018868319
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,80 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable, TypeVar
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def build_transform_script(script: Callable[[ValueT], None]):
+    print("\nTEST:", script.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        insert_transform_script(module, script=script, dump_script=True)
+        module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+@build_transform_script
+def test_match_ops_single(op: OpHandle):
+    op.match_ops(scf.ForOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+    # CHECK-SAME:    in %[[VAL_0]]
+    # CHECK-SAME:      -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+@build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+    op.match_ops("linalg.matmul")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+@build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+    op.match_ops("LinalgOp")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+@build_transform_script
+def test_match_ops_iface(op: OpHandle):
+    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+@build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+    op.match_ops([scf.ForOp, scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+@build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op

>From 86fb474f795338fc667f34324debac679b924457 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:21:47 +0100
Subject: [PATCH 2/5] cleaning up comments

---
 .../mlir/dialects/transform/extras/__init__.py      | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 0ccbb2e9254fd2..4b864cf42a53b6 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -9,7 +9,7 @@
     from ....dialects import transform
     from ....dialects.transform import structured
 except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports") from e
 
 
 @dataclass
@@ -41,8 +41,8 @@ def match_ops(
         | structured.MatchInterfaceEnum
         | Sequence[str | ir.OpView],
     ) -> OpHandle:
-        """Returns a handle to ops that match the given names, types, or interface.
-
+        """
+        Returns a handle to ops that match the given names, types, or interface.
         If only a single type is given, the value wrapped by the resulting
         handle is populated with the respective type.
         """
@@ -90,13 +90,8 @@ def insert_transform_script(
     script: Callable[[ValueT], None],
     dump_script: bool = False,
 ) -> None:
-    """Inserts the transform script of the schedule into the module.
+    """Inserts the transform script of the schedule into the module."""
 
-    Args:
-        module: Existing module into which the script should be inserted.
-        script: The transform script to apply at.
-        dump_script: Whether to dump the script after creation.
-    """
     # Insert the script into the IR
     with module.context, ir.Location.unknown(module.context):
         with ir.InsertionPoint.at_block_begin(module.body):

>From ee8f9c2a8429aaa6d0e16fd786d5943ce9e12227 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:24:15 +0100
Subject: [PATCH 3/5] add in-depth documentation to `insert_transform_script`

---
 .../dialects/transform/extras/__init__.py     | 30 +++++++++++++++----
 mlir/test/python/dialects/transform_extras.py |  8 ++---
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 4b864cf42a53b6..8e6788923d796e 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -82,17 +82,35 @@ def match_ops(
         return handle
 
 
-ValueT = TypeVar("ValueT", bound=Value)
-
-
 def insert_transform_script(
     module: ir.Module,
-    script: Callable[[ValueT], None],
+    script: Callable[[OpHandle], None],
     dump_script: bool = False,
 ) -> None:
-    """Inserts the transform script of the schedule into the module."""
+    """
+    Inserts the transform script of the schedule into the module. The script
+    should accept an instance of OpHandle as argument, which will be called with
+    the block arg of the newly created sequence op.
+
+    Example:
+    This python code
+    ```
+    module = ir.Module.create()
+    def test_match_ops_single(module: OpHandle):
+        module.match_ops(scf.ForOp)
+    insert_transform_script(module, script)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.sequence failures(propagate) {
+        ^bb0(%arg0: !transform.any_op):
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+        }
+    }
+    ```
+    """
 
-    # Insert the script into the IR
     with module.context, ir.Location.unknown(module.context):
         with ir.InsertionPoint.at_block_begin(module.body):
             sequence_op = transform.SequenceOp(
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 8c482018868319..08d853d8c2bc77 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -1,15 +1,13 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from typing import Callable, TypeVar
+from typing import Callable
 from mlir import ir
 from mlir.dialects import scf
 from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
 
-ValueT = TypeVar("ValueT", bound=Value)
 
-
-def build_transform_script(script: Callable[[ValueT], None]):
+def build_transform_script(script: Callable[[OpHandle], None]):
     print("\nTEST:", script.__name__)
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()

>From 7ee6bb0245b46bd6011e17bb25d4635ed2bea589 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:30:09 +0100
Subject: [PATCH 4/5] remove unneeded import

---
 mlir/python/mlir/dialects/transform/extras/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8e6788923d796e..9f1f752bd7dba4 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
 
 import abc
 from dataclasses import dataclass, field
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Optional, Sequence
 
 try:
     from .... import ir

>From fd1ddd1bca4d43c49862274c12981b36e7a1ffe8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Wed, 13 Dec 2023 15:42:33 +0100
Subject: [PATCH 5/5] Address PR comments - expose `get_static_typeid` for
 Python mlir type subclasses - expose typeid getters for transform dialect
 types to CAPI - sequence -> named_sequence - more general
 `insert_transform_script` - use `value_caster` automation instead of
 composition - clarifying documentation - cleaning up imports - add missing
 license header

---
 mlir/include/mlir-c/Dialect/Transform.h       |   8 +
 .../mlir/Bindings/Python/PybindAdaptors.h     |   2 +
 mlir/lib/Bindings/Python/DialectTransform.cpp |  12 +-
 mlir/lib/CAPI/Dialect/Transform.cpp           |  18 ++-
 mlir/python/CMakeLists.txt                    |  16 +-
 .../dialects/transform/extras/__init__.py     | 126 ---------------
 .../extras/dialects/transform/__init__.py     | 147 ++++++++++++++++++
 mlir/test/python/dialects/transform_extras.py |  45 ++++--
 8 files changed, 221 insertions(+), 153 deletions(-)
 delete mode 100644 mlir/python/mlir/dialects/transform/extras/__init__.py
 create mode 100644 mlir/python/mlir/extras/dialects/transform/__init__.py

diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 91c99b1f869f22..02c99b59218825 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -41,6 +45,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -63,6 +69,8 @@ mlirTransformOperationTypeGetOperationName(MlirType type);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
                                                       MlirType type);
 
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 5e0e56fc00a673..125f9b92937bff 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -495,6 +495,8 @@ class mlir_type_subclass : public pure_subclass {
           .attr("replace")(superCls.attr("__name__"), captureTypeName);
     });
     if (getTypeIDFunction) {
+      def_staticmethod("get_static_typeid",
+                       [getTypeIDFunction]() { return getTypeIDFunction(); });
       py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
           .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
               getTypeIDFunction())(pybind11::cpp_function(
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index c7764f4e7aecac..6b57e652aa9d8b 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -27,7 +27,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto anyOpType =
-      mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
+      mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
+                         mlirTransformAnyOpTypeGetTypeID);
   anyOpType.def_classmethod(
       "get",
       [](py::object cls, MlirContext ctx) {
@@ -41,7 +42,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto anyParamType =
-      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
+                         mlirTransformAnyParamTypeGetTypeID);
   anyParamType.def_classmethod(
       "get",
       [](py::object cls, MlirContext ctx) {
@@ -55,7 +57,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto anyValueType =
-      mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
+      mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
+                         mlirTransformAnyValueTypeGetTypeID);
   anyValueType.def_classmethod(
       "get",
       [](py::object cls, MlirContext ctx) {
@@ -96,7 +99,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto paramType =
-      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
+                         mlirTransformParamTypeGetTypeID);
   paramType.def_classmethod(
       "get",
       [](py::object cls, MlirType type, MlirContext ctx) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 3f7f8b8e2113fe..5fd773572bd3c8 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -25,6 +25,10 @@ bool mlirTypeIsATransformAnyOpType(MlirType type) {
   return isa<transform::AnyOpType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
+  return wrap(transform::AnyOpType::getTypeID());
+}
+
 MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
   return wrap(transform::AnyOpType::get(unwrap(ctx)));
 }
@@ -37,6 +41,10 @@ bool mlirTypeIsATransformAnyParamType(MlirType type) {
   return isa<transform::AnyParamType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
+  return wrap(transform::AnyParamType::getTypeID());
+}
+
 MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
   return wrap(transform::AnyParamType::get(unwrap(ctx)));
 }
@@ -49,6 +57,10 @@ bool mlirTypeIsATransformAnyValueType(MlirType type) {
   return isa<transform::AnyValueType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
+  return wrap(transform::AnyValueType::getTypeID());
+}
+
 MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
   return wrap(transform::AnyValueType::get(unwrap(ctx)));
 }
@@ -76,13 +88,17 @@ MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
 }
 
 //===---------------------------------------------------------------------===//
-// AnyOpType
+// ParamType
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsATransformParamType(MlirType type) {
   return isa<transform::ParamType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformParamTypeGetTypeID(void) {
+  return wrap(transform::ParamType::getTypeID());
+}
+
 MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
   return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
 }
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8013b49dbf9d68..41d91cf6778338 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -166,6 +166,14 @@ declare_mlir_dialect_python_bindings(
     "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
 )
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.extras
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  GEN_ENUM_BINDINGS
+  SOURCES
+    extras/dialects/transform/__init__.py)
+
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -311,14 +319,6 @@ declare_mlir_dialect_python_bindings(
     dialects/rocdl.py
   DIALECT_NAME rocdl)
 
-declare_mlir_python_sources(
-  MLIRPythonSources.Dialects.transform.extras
-  ADD_TO_PARENT MLIRPythonSources.Dialects
-  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
-  GEN_ENUM_BINDINGS
-  SOURCES
-    dialects/transform/extras/__init__.py)
-
 declare_mlir_python_sources(
   MLIRPythonSources.Dialects.quant
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
deleted file mode 100644
index 9f1f752bd7dba4..00000000000000
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from __future__ import annotations
-
-import abc
-from dataclasses import dataclass, field
-from typing import Callable, Optional, Sequence
-
-try:
-    from .... import ir
-    from ....dialects import transform
-    from ....dialects.transform import structured
-except ImportError as e:
-    raise RuntimeError("Error loading imports") from e
-
-
-@dataclass
-class Value(abc.ABC):
-    """Wrapper around a transform value handle with methods to chain further transforms."""
-
-    _mlir_value: ir.Value
-    children: list[Value] = field(default_factory=list)
-    parent: Optional[Value] = None
-
-    @property
-    def mlir_value(self) -> ir.Value:
-        return self._mlir_value
-
-
-@dataclass
-class Param(Value):
-    """Wrapper around a transform Param with methods to chain further transforms."""
-
-
-@dataclass
-class OpHandle(Value):
-    """Wrapper around a transform OpHandle with methods to chain further transforms."""
-
-    def match_ops(
-        self,
-        ops: str
-        | ir.OpView
-        | structured.MatchInterfaceEnum
-        | Sequence[str | ir.OpView],
-    ) -> OpHandle:
-        """
-        Returns a handle to ops that match the given names, types, or interface.
-        If only a single type is given, the value wrapped by the resulting
-        handle is populated with the respective type.
-        """
-        # Handle interface.
-        if isinstance(ops, structured.MatchInterfaceEnum) or (
-            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
-        ):
-            if isinstance(ops, str):
-                ops = structured.MatchInterfaceEnum[ops]
-            match_op = structured.MatchOp(
-                transform.AnyOpType.get(),
-                self.mlir_value,
-                interface=ops,
-            )
-
-        # Handle op name(s), either given directly as string or given as op.
-        else:
-            if isinstance(ops, str):
-                op_type = transform.OperationType.get(ops)
-                op_names = [ops]
-            elif isinstance(ops, Sequence):
-                op_type = transform.AnyOpType.get()
-                op_names = [
-                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
-                ]
-            else:
-                op_type = transform.OperationType.get(ops.OPERATION_NAME)
-                op_names = [ops.OPERATION_NAME]
-            match_op = structured.MatchOp.match_op_names(
-                op_type,
-                self.mlir_value,
-                op_names,
-            )
-
-        handle = OpHandle(match_op.results_, parent=self)
-        self.children.append(handle)
-        return handle
-
-
-def insert_transform_script(
-    module: ir.Module,
-    script: Callable[[OpHandle], None],
-    dump_script: bool = False,
-) -> None:
-    """
-    Inserts the transform script of the schedule into the module. The script
-    should accept an instance of OpHandle as argument, which will be called with
-    the block arg of the newly created sequence op.
-
-    Example:
-    This python code
-    ```
-    module = ir.Module.create()
-    def test_match_ops_single(module: OpHandle):
-        module.match_ops(scf.ForOp)
-    insert_transform_script(module, script)
-    ```
-    generates the following IR:
-    ```
-    module {
-        transform.sequence failures(propagate) {
-        ^bb0(%arg0: !transform.any_op):
-            %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
-        }
-    }
-    ```
-    """
-
-    with module.context, ir.Location.unknown(module.context):
-        with ir.InsertionPoint.at_block_begin(module.body):
-            sequence_op = transform.SequenceOp(
-                transform.FailurePropagationMode.Propagate,
-                (),
-                transform.AnyOpType.get(),
-            )
-        with ir.InsertionPoint(sequence_op.body):
-            script(OpHandle(sequence_op.bodyTarget))
-            transform.YieldOp([])
-
-    if dump_script:
-        print(sequence_op)
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/extras/dialects/transform/__init__.py
new file mode 100644
index 00000000000000..5870fa4fb1661d
--- /dev/null
+++ b/mlir/python/mlir/extras/dialects/transform/__init__.py
@@ -0,0 +1,147 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from __future__ import annotations
+from typing import Callable, Optional, Sequence
+
+from .... import ir
+from ....dialects import transform
+from ....dialects.transform import structured
+
+
+class Handle(ir.Value):
+    """
+    Base class for wrappers around different types of transform handle with
+    methods to chain further transforms.
+
+    The fields `children` and `parent` are used to capture the relation of
+    handles statically in order to enable further analysis. The payload
+    operation of a child handle is nested into a region of the payload operation
+    of the corresponding parent handle.
+    """
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v)
+        self.parent = parent
+        self.children = children if children is not None else []
+
+
+@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
+@ir.register_value_caster(transform.OperationType.get_static_typeid())
+class OpHandle(Handle):
+    """
+    Wrapper around a transform operation handle with methods to chain further
+    transforms.
+    """
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v, parent=parent, children=children)
+
+    def match_ops(
+        self,
+        ops: str
+        | ir.OpView
+        | structured.MatchInterfaceEnum
+        | Sequence[str | ir.OpView],
+    ) -> OpHandle:
+        """
+        Returns a handle to ops that match the given names, types, or interface.
+        If only a single type is given, the value wrapped by the resulting
+        handle is populated with the respective type.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+def insert_transform_script(
+    block_or_insertion_point: ir.Block | ir.InsertionPoint,
+    script: Callable[[OpHandle], None],
+    dump_script: bool = False,
+) -> None:
+    """
+    Inserts the transform script of the schedule into the module. The script
+    should accept an instance of OpHandle as argument, which will be called with
+    the block arg of the newly created named_sequence op.
+
+    Example:
+    This python code
+    ```
+    module = ir.Module.create()
+    def test_match_ops_single(module: OpHandle):
+        module.match_ops(scf.ForOp)
+    insert_transform_script(module.body, script)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+        ^bb0(%arg0: !transform.any_op):
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0
+                 : (!transform.any_op) -> !transform.op<"scf.for">
+        }
+    }
+    ```
+    """
+    if isinstance(block_or_insertion_point, ir.Block):
+        context = block_or_insertion_point.owner.context
+        insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point)
+    else:
+        context = block_or_insertion_point.block.owner.context
+        insertion_point = block_or_insertion_point
+
+    with context, ir.Location.unknown(context):
+        with insertion_point:
+            named_sequence_op = transform.NamedSequenceOp(
+                "__transform_main", [transform.AnyOpType.get()], []
+            )
+        with ir.InsertionPoint(named_sequence_op.body):
+            script(named_sequence_op.bodyTarget)
+            transform.YieldOp([])
+
+    if dump_script:
+        print(named_sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 08d853d8c2bc77..dbfa8a2dc73c41 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,23 +4,45 @@
 from mlir import ir
 from mlir.dialects import scf
 from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+from mlir.extras.dialects.transform import OpHandle, insert_transform_script
 
 
 def build_transform_script(script: Callable[[OpHandle], None]):
     print("\nTEST:", script.__name__)
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
-        insert_transform_script(module, script=script, dump_script=True)
+        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+        insert_transform_script(module.body, script=script, dump_script=True)
         module.operation.verify()
 
 
+def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
+    print("\nTEST:", script.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+        insert_transform_script(
+            ir.InsertionPoint.at_block_begin(module.body),
+            script=script,
+            dump_script=True,
+        )
+        module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_build_script_at_insertion_point
+@build_transform_script_at_insertion_point
+def test_build_script_at_insertion_point(op: OpHandle):
+    pass
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.yield
+    # CHECK-NEXT: }
+
+
 # CHECK-LABEL: TEST: test_match_ops_single
 @build_transform_script
 def test_match_ops_single(op: OpHandle):
     op.match_ops(scf.ForOp)
-    # CHECK: transform.sequence
-    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
     # CHECK-SAME:    in %[[VAL_0]]
     # CHECK-SAME:      -> !transform.op<"scf.for">
@@ -30,8 +52,7 @@ def test_match_ops_single(op: OpHandle):
 @build_transform_script
 def test_match_ops_string_name(op: OpHandle):
     op.match_ops("linalg.matmul")
-    # CHECK: transform.sequence
-    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
 
@@ -40,8 +61,7 @@ def test_match_ops_string_name(op: OpHandle):
 @build_transform_script
 def test_match_ops_string_iface(op: OpHandle):
     op.match_ops("LinalgOp")
-    # CHECK: transform.sequence
-    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
 
@@ -50,8 +70,7 @@ def test_match_ops_string_iface(op: OpHandle):
 @build_transform_script
 def test_match_ops_iface(op: OpHandle):
     op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
-    # CHECK: transform.sequence
-    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
 
@@ -60,8 +79,7 @@ def test_match_ops_iface(op: OpHandle):
 @build_transform_script
 def test_match_ops_multiple(op: OpHandle):
     op.match_ops([scf.ForOp, scf.ForallOp])
-    # CHECK: transform.sequence
-    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
     # CHECK-SAME:     -> !transform.any_op
@@ -71,8 +89,7 @@ def test_match_ops_multiple(op: OpHandle):
 @build_transform_script
 def test_match_ops_mixed(op: OpHandle):
     op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
-    # CHECK: transform.sequence
-    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
     # CHECK-SAME:     -> !transform.any_op



More information about the Mlir-commits mailing list