[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 09:02:40 PST 2023


Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>
Message-ID: <llvm.org/llvm/llvm-project/pull/75073 at github.com>
In-Reply-To:


https://github.com/martin-luecke created https://github.com/llvm/llvm-project/pull/75073

This adds Python abstractions for the different handle types of the transform dialect.

The abstractions allow for straightforward chaining of transforms by calling their member functions.
As an initial PR for this infrastructure, only a single transform is included: `transform.structured.match`. 
With a future `tile` transform abstraction, usage would look like this:
```Python
def script(module: OpHandle):
    module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32])
```
to generate the following IR:
```mlir
%0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32]
```

These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains.

>From e02d2b42e5595ee9abdc297648fc35165a92c967 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Wed, 6 Dec 2023 16:32:56 +0100
Subject: [PATCH 1/4] [MLIR][transform][python] add extras for sugared python
 abstractions

---
 mlir/python/CMakeLists.txt                    |   8 ++
 .../dialects/transform/extras/__init__.py     | 113 ++++++++++++++++++
 mlir/test/python/dialects/transform_extras.py |  80 +++++++++++++
 3 files changed, 201 insertions(+)
 create mode 100644 mlir/python/mlir/dialects/transform/extras/__init__.py
 create mode 100644 mlir/test/python/dialects/transform_extras.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc2633..8013b49dbf9d6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
     dialects/rocdl.py
   DIALECT_NAME rocdl)
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.extras
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  GEN_ENUM_BINDINGS  # NOTE(review): looks like a declare_mlir_dialect_python_bindings option; likely ignored here - confirm and drop
+  SOURCES
+    dialects/transform/extras/__init__.py)
+
 declare_mlir_python_sources(
   MLIRPythonSources.Dialects.quant
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 0000000000000..0ccbb2e9254fd
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence, TypeVar
+
+try:
+    from .... import ir
+    from ....dialects import transform
+    from ....dialects.transform import structured
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+
+ at dataclass
+class Value(abc.ABC):
+    """Wrapper around a transform value handle with methods to chain further transforms."""
+
+    _mlir_value: ir.Value
+    children: list[Value] = field(default_factory=list)
+    parent: Optional[Value] = None
+
+    @property
+    def mlir_value(self) -> ir.Value:
+        return self._mlir_value
+
+
+ at dataclass
+class Param(Value):
+    """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+ at dataclass
+class OpHandle(Value):
+    """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+    def match_ops(
+        self,
+        ops: str
+        | ir.OpView
+        | structured.MatchInterfaceEnum
+        | Sequence[str | ir.OpView],
+    ) -> OpHandle:
+        """Returns a handle to ops that match the given names, types, or interface.
+
+        If only a single type is given, the value wrapped by the resulting
+        handle is populated with the respective type.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self.mlir_value,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self.mlir_value,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def insert_transform_script(
+    module: ir.Module,
+    script: Callable[[OpHandle], None],
+    dump_script: bool = False,
+) -> None:
+    """Inserts a transform.sequence built by `script` at the start of the module.
+
+    Args:
+        module: Existing module into which the script should be inserted.
+        script: Called with an OpHandle wrapping the sequence's block argument.
+        dump_script: Whether to print the sequence op after creation.
+    """
+    # Build everything in the module's context, with unknown locations.
+    with module.context, ir.Location.unknown(module.context):
+        with ir.InsertionPoint.at_block_begin(module.body):  # sequence goes first in the module
+            sequence_op = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                (),
+                transform.AnyOpType.get(),
+            )
+        with ir.InsertionPoint(sequence_op.body):
+            script(OpHandle(sequence_op.bodyTarget))  # user script emits ops into the body
+            transform.YieldOp([])  # terminate the sequence body after the script's ops
+
+    if dump_script:
+        print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 0000000000000..8c48201886831
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,80 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable, TypeVar
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def build_transform_script(script: Callable[[OpHandle], None]):  # used as a decorator; runs at import and returns None
+    print("\nTEST:", script.__name__)  # label matched by the CHECK-LABEL lines
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        insert_transform_script(module, script=script, dump_script=True)  # prints the sequence for FileCheck
+        module.operation.verify()  # run the IR verifier on the finished module
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+ at build_transform_script
+def test_match_ops_single(op: OpHandle):
+    op.match_ops(scf.ForOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+    # CHECK-SAME:    in %[[VAL_0]]
+    # CHECK-SAME:      -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+ at build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+    op.match_ops("linalg.matmul")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+ at build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+    op.match_ops("LinalgOp")
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+ at build_transform_script
+def test_match_ops_iface(op: OpHandle):
+    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+ at build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+    op.match_ops([scf.ForOp, scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+ at build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+    # CHECK: transform.sequence
+    # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op

>From 77d6011002379dd2ae47268ee60fdd5b502a1638 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:21:47 +0100
Subject: [PATCH 2/4] cleaning up comments

---
 .../mlir/dialects/transform/extras/__init__.py      | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 0ccbb2e9254fd..4b864cf42a53b 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -9,7 +9,7 @@
     from ....dialects import transform
     from ....dialects.transform import structured
 except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports") from e
 
 
 @dataclass
@@ -41,8 +41,8 @@ def match_ops(
         | structured.MatchInterfaceEnum
         | Sequence[str | ir.OpView],
     ) -> OpHandle:
-        """Returns a handle to ops that match the given names, types, or interface.
-
+        """
+        Returns a handle to ops that match the given names, types, or interface.
         If only a single type is given, the value wrapped by the resulting
         handle is populated with the respective type.
         """
@@ -90,13 +90,8 @@ def insert_transform_script(
     script: Callable[[ValueT], None],
     dump_script: bool = False,
 ) -> None:
-    """Inserts the transform script of the schedule into the module.
+    """Inserts the transform script of the schedule into the module."""
 
-    Args:
-        module: Existing module into which the script should be inserted.
-        script: The transform script to apply at.
-        dump_script: Whether to dump the script after creation.
-    """
     # Insert the script into the IR
     with module.context, ir.Location.unknown(module.context):
         with ir.InsertionPoint.at_block_begin(module.body):

>From df9513a77d7015400b8c0c48745fc940658699cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:24:15 +0100
Subject: [PATCH 3/4] add in-depth documentation to `insert_transform_script`

---
 .../dialects/transform/extras/__init__.py     | 30 +++++++++++++++----
 mlir/test/python/dialects/transform_extras.py |  8 ++---
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 4b864cf42a53b..8e6788923d796 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -82,17 +82,35 @@ def match_ops(
         return handle
 
 
-ValueT = TypeVar("ValueT", bound=Value)
-
-
 def insert_transform_script(
     module: ir.Module,
-    script: Callable[[ValueT], None],
+    script: Callable[[OpHandle], None],
     dump_script: bool = False,
 ) -> None:
-    """Inserts the transform script of the schedule into the module."""
+    """
+    Inserts the transform script of the schedule into the module. The script
+    must accept a single OpHandle argument; it is called with a handle to the
+    block argument of the newly created sequence op.
+
+    Example:
+    This python code
+    ```
+    module = ir.Module.create()
+    def script(op: OpHandle):
+        op.match_ops(scf.ForOp)
+    insert_transform_script(module, script)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.sequence failures(propagate) {
+        ^bb0(%arg0: !transform.any_op):
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+        }
+    }
+    ```
+    """
 
-    # Insert the script into the IR
     with module.context, ir.Location.unknown(module.context):
         with ir.InsertionPoint.at_block_begin(module.body):
             sequence_op = transform.SequenceOp(
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 8c48201886831..08d853d8c2bc7 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -1,15 +1,13 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from typing import Callable, TypeVar
+from typing import Callable
 from mlir import ir
 from mlir.dialects import scf
 from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
 
-ValueT = TypeVar("ValueT", bound=Value)
 
-
-def build_transform_script(script: Callable[[ValueT], None]):
+def build_transform_script(script: Callable[[OpHandle], None]):
     print("\nTEST:", script.__name__)
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()

>From 9bce456cc36130b1267f6d66dd4f0d460a413b6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:30:09 +0100
Subject: [PATCH 4/4] remove unneeded import

---
 mlir/python/mlir/dialects/transform/extras/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8e6788923d796..9f1f752bd7dba 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
 
 import abc
 from dataclasses import dataclass, field
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Optional, Sequence
 
 try:
     from .... import ir



More information about the Mlir-commits mailing list