[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 13 05:56:08 PST 2023


Martin Lücke <martin.luecke at ed.ac.uk>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/75073 at github.com>


https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/75073

>From 5ec81e059c9c1a2be91782beca501a5e0d5df474 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Wed, 6 Dec 2023 16:32:56 +0100
Subject: [PATCH 1/4] [MLIR][transform][python] add extras for sugared python
 abstractions

---
 mlir/python/CMakeLists.txt                    |   8 ++
 .../dialects/transform/extras/__init__.py     | 113 ++++++++++++++++++
 mlir/test/python/dialects/transform_extras.py |  80 +++++++++++++
 3 files changed, 201 insertions(+)
 create mode 100644 mlir/python/mlir/dialects/transform/extras/__init__.py
 create mode 100644 mlir/test/python/dialects/transform_extras.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc26335..8013b49dbf9d68 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
     dialects/rocdl.py
   DIALECT_NAME rocdl)
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.extras
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  GEN_ENUM_BINDINGS
+  SOURCES
+    dialects/transform/extras/__init__.py)
+
 declare_mlir_python_sources(
   MLIRPythonSources.Dialects.quant
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 00000000000000..0ccbb2e9254fd2
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence, TypeVar
+
+try:
+    from .... import ir
+    from ....dialects import transform
+    from ....dialects.transform import structured
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+
+ at dataclass
+class Value(abc.ABC):
+    """Wrapper around a transform value handle with methods to chain further transforms."""
+
+    _mlir_value: ir.Value
+    children: list[Value] = field(default_factory=list)
+    parent: Optional[Value] = None
+
+    @property
+    def mlir_value(self) -> ir.Value:
+        return self._mlir_value
+
+
+ at dataclass
+class Param(Value):
+    """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+ at dataclass
+class OpHandle(Value):
+    """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+    def match_ops(
+        self,
+        ops: str
+        | ir.OpView
+        | structured.MatchInterfaceEnum
+        | Sequence[str | ir.OpView],
+    ) -> OpHandle:
+        """Returns a handle to ops that match the given names, types, or interface.
+
+        If only a single type is given, the value wrapped by the resulting
+        handle is populated with the respective type.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self.mlir_value,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self.mlir_value,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def insert_transform_script(
+    module: ir.Module,
+    script: Callable[[ValueT], None],
+    dump_script: bool = False,
+) -> None:
+    """Inserts the transform script of the schedule into the module.
+
+    Args:
+        module: Existing module into which the script should be inserted.
+        script: The transform script to apply at.
+        dump_script: Whether to dump the script after creation.
+    """
+    # Insert the script into the IR
+    with module.context, ir.Location.unknown(module.context):
+        with ir.InsertionPoint.at_block_begin(module.body):
+            sequence_op = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                (),
+                transform.AnyOpType.get(),
+            )
+        with ir.InsertionPoint(sequence_op.body):
+            script(OpHandle(sequence_op.bodyTarget))
+            transform.YieldOp([])
+
+    if dump_script:
+        print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 00000000000000..8c482018868319
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,80 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable, TypeVar
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def build_transform_script(script: Callable[[ValueT], None]):
+    print("\nTEST:", script.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        insert_transform_script(module, script=script, dump_script=True)
+        module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+ at build_transform_script
+def test_match_ops_single(op: OpHandle):
+    op.match_ops(scf.ForOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+    # CHECK-SAME:    in %[[VAL_0]]
+    # CHECK-SAME:      -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+ at build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+    op.match_ops("linalg.matmul")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+ at build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+    op.match_ops("LinalgOp")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+ at build_transform_script
+def test_match_ops_iface(op: OpHandle):
+    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+ at build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+    op.match_ops([scf.ForOp, scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+ at build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op

>From 86fb474f795338fc667f34324debac679b924457 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:21:47 +0100
Subject: [PATCH 2/4] cleaning up comments

---
 .../mlir/dialects/transform/extras/__init__.py      | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 0ccbb2e9254fd2..4b864cf42a53b6 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -9,7 +9,7 @@
     from ....dialects import transform
     from ....dialects.transform import structured
 except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports") from e
 
 
 @dataclass
@@ -41,8 +41,8 @@ def match_ops(
         | structured.MatchInterfaceEnum
         | Sequence[str | ir.OpView],
     ) -> OpHandle:
-        """Returns a handle to ops that match the given names, types, or interface.
-
+        """
+        Returns a handle to ops that match the given names, types, or interface.
         If only a single type is given, the value wrapped by the resulting
         handle is populated with the respective type.
         """
@@ -90,13 +90,8 @@ def insert_transform_script(
     script: Callable[[ValueT], None],
     dump_script: bool = False,
 ) -> None:
-    """Inserts the transform script of the schedule into the module.
+    """Inserts the transform script of the schedule into the module."""
 
-    Args:
-        module: Existing module into which the script should be inserted.
-        script: The transform script to apply at.
-        dump_script: Whether to dump the script after creation.
-    """
     # Insert the script into the IR
     with module.context, ir.Location.unknown(module.context):
         with ir.InsertionPoint.at_block_begin(module.body):

>From ee8f9c2a8429aaa6d0e16fd786d5943ce9e12227 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:24:15 +0100
Subject: [PATCH 3/4] add in-depth documentation to `insert_transform_script`

---
 .../dialects/transform/extras/__init__.py     | 30 +++++++++++++++----
 mlir/test/python/dialects/transform_extras.py |  8 ++---
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 4b864cf42a53b6..8e6788923d796e 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -82,17 +82,35 @@ def match_ops(
         return handle
 
 
-ValueT = TypeVar("ValueT", bound=Value)
-
-
 def insert_transform_script(
     module: ir.Module,
-    script: Callable[[ValueT], None],
+    script: Callable[[OpHandle], None],
     dump_script: bool = False,
 ) -> None:
-    """Inserts the transform script of the schedule into the module."""
+    """
+    Inserts the transform script of the schedule into the module. The script
+    should accept an instance of OpHandle as argument, which will be called with
+    the block arg of the newly created sequence op.
+
+    Example:
+    This python code
+    ```
+    module = ir.Module.create()
+    def test_match_ops_single(module: OpHandle):
+        module.match_ops(scf.ForOp)
+    insert_transform_script(module, script)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.sequence failures(propagate) {
+        ^bb0(%arg0: !transform.any_op):
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+        }
+    }
+    ```
+    """
 
-    # Insert the script into the IR
     with module.context, ir.Location.unknown(module.context):
         with ir.InsertionPoint.at_block_begin(module.body):
             sequence_op = transform.SequenceOp(
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 8c482018868319..08d853d8c2bc77 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -1,15 +1,13 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from typing import Callable, TypeVar
+from typing import Callable
 from mlir import ir
 from mlir.dialects import scf
 from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
 
-ValueT = TypeVar("ValueT", bound=Value)
 
-
-def build_transform_script(script: Callable[[ValueT], None]):
+def build_transform_script(script: Callable[[OpHandle], None]):
     print("\nTEST:", script.__name__)
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()

>From 7ee6bb0245b46bd6011e17bb25d4635ed2bea589 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:30:09 +0100
Subject: [PATCH 4/4] remove unneeded import

---
 mlir/python/mlir/dialects/transform/extras/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8e6788923d796e..9f1f752bd7dba4 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
 
 import abc
 from dataclasses import dataclass, field
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Optional, Sequence
 
 try:
     from .... import ir



More information about the Mlir-commits mailing list