[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 13 05:56:08 PST 2023
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/75073 at github.com>
https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/75073
From 5ec81e059c9c1a2be91782beca501a5e0d5df474 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Wed, 6 Dec 2023 16:32:56 +0100
Subject: [PATCH 1/4] [MLIR][transform][python] add extras for sugared python
abstractions
---
mlir/python/CMakeLists.txt | 8 ++
.../dialects/transform/extras/__init__.py | 113 ++++++++++++++++++
mlir/test/python/dialects/transform_extras.py | 80 +++++++++++++
3 files changed, 201 insertions(+)
create mode 100644 mlir/python/mlir/dialects/transform/extras/__init__.py
create mode 100644 mlir/test/python/dialects/transform_extras.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc26335..8013b49dbf9d68 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
dialects/rocdl.py
DIALECT_NAME rocdl)
+declare_mlir_python_sources(
+ MLIRPythonSources.Dialects.transform.extras
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ GEN_ENUM_BINDINGS
+ SOURCES
+ dialects/transform/extras/__init__.py)
+
declare_mlir_python_sources(
MLIRPythonSources.Dialects.quant
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 00000000000000..0ccbb2e9254fd2
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence, TypeVar
+
+try:
+ from .... import ir
+ from ....dialects import transform
+ from ....dialects.transform import structured
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+@dataclass
+class Value(abc.ABC):
+ """Wrapper around a transform value handle with methods to chain further transforms."""
+
+ _mlir_value: ir.Value
+ children: list[Value] = field(default_factory=list)
+ parent: Optional[Value] = None
+
+ @property
+ def mlir_value(self) -> ir.Value:
+ return self._mlir_value
+
+
+@dataclass
+class Param(Value):
+ """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+@dataclass
+class OpHandle(Value):
+ """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+ def match_ops(
+ self,
+ ops: str
+ | ir.OpView
+ | structured.MatchInterfaceEnum
+ | Sequence[str | ir.OpView],
+ ) -> OpHandle:
+ """Returns a handle to ops that match the given names, types, or interface.
+
+ If only a single type is given, the value wrapped by the resulting
+ handle is populated with the respective type.
+ """
+ # Handle interface.
+ if isinstance(ops, structured.MatchInterfaceEnum) or (
+ isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+ ):
+ if isinstance(ops, str):
+ ops = structured.MatchInterfaceEnum[ops]
+ match_op = structured.MatchOp(
+ transform.AnyOpType.get(),
+ self.mlir_value,
+ interface=ops,
+ )
+
+ # Handle op name(s), either given directly as string or given as op.
+ else:
+ if isinstance(ops, str):
+ op_type = transform.OperationType.get(ops)
+ op_names = [ops]
+ elif isinstance(ops, Sequence):
+ op_type = transform.AnyOpType.get()
+ op_names = [
+ op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+ ]
+ else:
+ op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_names = [ops.OPERATION_NAME]
+ match_op = structured.MatchOp.match_op_names(
+ op_type,
+ self.mlir_value,
+ op_names,
+ )
+
+ handle = OpHandle(match_op.results_, parent=self)
+ self.children.append(handle)
+ return handle
+
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def insert_transform_script(
+ module: ir.Module,
+ script: Callable[[ValueT], None],
+ dump_script: bool = False,
+) -> None:
+ """Inserts the transform script of the schedule into the module.
+
+ Args:
+ module: Existing module into which the script should be inserted.
+ script: The transform script to apply at.
+ dump_script: Whether to dump the script after creation.
+ """
+ # Insert the script into the IR
+ with module.context, ir.Location.unknown(module.context):
+ with ir.InsertionPoint.at_block_begin(module.body):
+ sequence_op = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ (),
+ transform.AnyOpType.get(),
+ )
+ with ir.InsertionPoint(sequence_op.body):
+ script(OpHandle(sequence_op.bodyTarget))
+ transform.YieldOp([])
+
+ if dump_script:
+ print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 00000000000000..8c482018868319
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,80 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable, TypeVar
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def build_transform_script(script: Callable[[ValueT], None]):
+ print("\nTEST:", script.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ insert_transform_script(module, script=script, dump_script=True)
+ module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+@build_transform_script
+def test_match_ops_single(op: OpHandle):
+ op.match_ops(scf.ForOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+ # CHECK-SAME: in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+@build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+ op.match_ops("linalg.matmul")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+@build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+ op.match_ops("LinalgOp")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+@build_transform_script
+def test_match_ops_iface(op: OpHandle):
+ op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+@build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+ op.match_ops([scf.ForOp, scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+@build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+ op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
From 86fb474f795338fc667f34324debac679b924457 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:21:47 +0100
Subject: [PATCH 2/4] cleaning up comments
---
.../mlir/dialects/transform/extras/__init__.py | 13 ++++---------
1 file changed, 4 insertions(+), 9 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 0ccbb2e9254fd2..4b864cf42a53b6 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -9,7 +9,7 @@
from ....dialects import transform
from ....dialects.transform import structured
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports") from e
@dataclass
@@ -41,8 +41,8 @@ def match_ops(
| structured.MatchInterfaceEnum
| Sequence[str | ir.OpView],
) -> OpHandle:
- """Returns a handle to ops that match the given names, types, or interface.
-
+ """
+ Returns a handle to ops that match the given names, types, or interface.
If only a single type is given, the value wrapped by the resulting
handle is populated with the respective type.
"""
@@ -90,13 +90,8 @@ def insert_transform_script(
script: Callable[[ValueT], None],
dump_script: bool = False,
) -> None:
- """Inserts the transform script of the schedule into the module.
+ """Inserts the transform script of the schedule into the module."""
- Args:
- module: Existing module into which the script should be inserted.
- script: The transform script to apply at.
- dump_script: Whether to dump the script after creation.
- """
# Insert the script into the IR
with module.context, ir.Location.unknown(module.context):
with ir.InsertionPoint.at_block_begin(module.body):
From ee8f9c2a8429aaa6d0e16fd786d5943ce9e12227 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:24:15 +0100
Subject: [PATCH 3/4] add in-depth documentation to `insert_transform_script`
---
.../dialects/transform/extras/__init__.py | 30 +++++++++++++++----
mlir/test/python/dialects/transform_extras.py | 8 ++---
2 files changed, 27 insertions(+), 11 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 4b864cf42a53b6..8e6788923d796e 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -82,17 +82,35 @@ def match_ops(
return handle
-ValueT = TypeVar("ValueT", bound=Value)
-
-
def insert_transform_script(
module: ir.Module,
- script: Callable[[ValueT], None],
+ script: Callable[[OpHandle], None],
dump_script: bool = False,
) -> None:
- """Inserts the transform script of the schedule into the module."""
+ """
+ Inserts the transform script of the schedule into the module. The script
+ should accept an OpHandle as its only argument; it is called with a handle
+ wrapping the block argument of the newly created sequence op.
+
+ Example:
+ This python code
+ ```
+ module = ir.Module.create()
+ def test_match_ops_single(op: OpHandle):
+ op.match_ops(scf.ForOp)
+ insert_transform_script(module, test_match_ops_single)
+ ```
+ generates the following IR:
+ ```
+ module {
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+ }
+ }
+ ```
+ """
- # Insert the script into the IR
with module.context, ir.Location.unknown(module.context):
with ir.InsertionPoint.at_block_begin(module.body):
sequence_op = transform.SequenceOp(
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 8c482018868319..08d853d8c2bc77 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -1,15 +1,13 @@
# RUN: %PYTHON %s | FileCheck %s
-from typing import Callable, TypeVar
+from typing import Callable
from mlir import ir
from mlir.dialects import scf
from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
-ValueT = TypeVar("ValueT", bound=Value)
-
-def build_transform_script(script: Callable[[ValueT], None]):
+def build_transform_script(script: Callable[[OpHandle], None]):
print("\nTEST:", script.__name__)
with ir.Context(), ir.Location.unknown():
module = ir.Module.create()
From 7ee6bb0245b46bd6011e17bb25d4635ed2bea589 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:30:09 +0100
Subject: [PATCH 4/4] remove unneeded import
---
mlir/python/mlir/dialects/transform/extras/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8e6788923d796e..9f1f752bd7dba4 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
import abc
from dataclasses import dataclass, field
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Optional, Sequence
try:
from .... import ir
More information about the Mlir-commits
mailing list