[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 11 09:03:10 PST 2023
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/75073 at github.com>
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (martin-luecke)
<details>
<summary>Changes</summary>
This adds Python abstractions for the different handle types of the transform dialect.
The abstractions allow for straightforward chaining of transforms by calling their member functions.
As an initial PR for this infrastructure, only a single transform is included: `transform.structured.match`.
With a future `tile` transform abstraction, an example of its usage would be:
```Python
def script(module: OpHandle):
    module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32])
```
to generate the following IR:
```mlir
%0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32]
```
These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains.
---
Full diff: https://github.com/llvm/llvm-project/pull/75073.diff
3 Files Affected:
- (modified) mlir/python/CMakeLists.txt (+8)
- (added) mlir/python/mlir/dialects/transform/extras/__init__.py (+126)
- (added) mlir/test/python/dialects/transform_extras.py (+78)
``````````diff
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc2633..8013b49dbf9d6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
dialects/rocdl.py
DIALECT_NAME rocdl)
+declare_mlir_python_sources(
+ MLIRPythonSources.Dialects.transform.extras
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ GEN_ENUM_BINDINGS
+ SOURCES
+ dialects/transform/extras/__init__.py)
+
declare_mlir_python_sources(
MLIRPythonSources.Dialects.quant
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 0000000000000..9f1f752bd7dba
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence
+
+try:
+ from .... import ir
+ from ....dialects import transform
+ from ....dialects.transform import structured
+except ImportError as e:
+ raise RuntimeError("Error loading imports") from e
+
+
+@dataclass
+class Value(abc.ABC):
+ """Wrapper around a transform value handle with methods to chain further transforms."""
+
+ _mlir_value: ir.Value
+ children: list[Value] = field(default_factory=list)
+ parent: Optional[Value] = None
+
+ @property
+ def mlir_value(self) -> ir.Value:
+ return self._mlir_value
+
+
+@dataclass
+class Param(Value):
+ """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+@dataclass
+class OpHandle(Value):
+ """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+ def match_ops(
+ self,
+ ops: str
+ | ir.OpView
+ | structured.MatchInterfaceEnum
+ | Sequence[str | ir.OpView],
+ ) -> OpHandle:
+ """
+ Returns a handle to ops that match the given names, types, or interface.
+ If only a single type is given, the value wrapped by the resulting
+ handle is populated with the respective type.
+ """
+ # Handle interface.
+ if isinstance(ops, structured.MatchInterfaceEnum) or (
+ isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+ ):
+ if isinstance(ops, str):
+ ops = structured.MatchInterfaceEnum[ops]
+ match_op = structured.MatchOp(
+ transform.AnyOpType.get(),
+ self.mlir_value,
+ interface=ops,
+ )
+
+ # Handle op name(s), either given directly as string or given as op.
+ else:
+ if isinstance(ops, str):
+ op_type = transform.OperationType.get(ops)
+ op_names = [ops]
+ elif isinstance(ops, Sequence):
+ op_type = transform.AnyOpType.get()
+ op_names = [
+ op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+ ]
+ else:
+ op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_names = [ops.OPERATION_NAME]
+ match_op = structured.MatchOp.match_op_names(
+ op_type,
+ self.mlir_value,
+ op_names,
+ )
+
+ handle = OpHandle(match_op.results_, parent=self)
+ self.children.append(handle)
+ return handle
+
+
+def insert_transform_script(
+ module: ir.Module,
+ script: Callable[[OpHandle], None],
+ dump_script: bool = False,
+) -> None:
+ """
+ Inserts the transform script of the schedule into the module. The script
+ should accept an instance of OpHandle as argument, which will be called with
+ the block arg of the newly created sequence op.
+
+ Example:
+ This python code
+ ```
+ module = ir.Module.create()
+ def script(module: OpHandle):
+ module.match_ops(scf.ForOp)
+ insert_transform_script(module, script)
+ ```
+ generates the following IR:
+ ```
+ module {
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+ }
+ }
+ ```
+ """
+
+ with module.context, ir.Location.unknown(module.context):
+ with ir.InsertionPoint.at_block_begin(module.body):
+ sequence_op = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ (),
+ transform.AnyOpType.get(),
+ )
+ with ir.InsertionPoint(sequence_op.body):
+ script(OpHandle(sequence_op.bodyTarget))
+ transform.YieldOp([])
+
+ if dump_script:
+ print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 0000000000000..08d853d8c2bc7
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,78 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+
+
+def build_transform_script(script: Callable[[OpHandle], None]):
+ print("\nTEST:", script.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ insert_transform_script(module, script=script, dump_script=True)
+ module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+@build_transform_script
+def test_match_ops_single(op: OpHandle):
+ op.match_ops(scf.ForOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+ # CHECK-SAME: in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+@build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+ op.match_ops("linalg.matmul")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+@build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+ op.match_ops("LinalgOp")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+@build_transform_script
+def test_match_ops_iface(op: OpHandle):
+ op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+@build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+ op.match_ops([scf.ForOp, scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+@build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+ op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
``````````
</details>
https://github.com/llvm/llvm-project/pull/75073
More information about the Mlir-commits
mailing list