[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 13 06:42:51 PST 2023
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>,
Martin Lücke <martin.luecke at ed.ac.uk>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/75073 at github.com>
https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/75073
>From 5ec81e059c9c1a2be91782beca501a5e0d5df474 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Wed, 6 Dec 2023 16:32:56 +0100
Subject: [PATCH 1/5] [MLIR][transform][python] add extras for sugared python
abstractions
---
mlir/python/CMakeLists.txt | 8 ++
.../dialects/transform/extras/__init__.py | 113 ++++++++++++++++++
mlir/test/python/dialects/transform_extras.py | 80 +++++++++++++
3 files changed, 201 insertions(+)
create mode 100644 mlir/python/mlir/dialects/transform/extras/__init__.py
create mode 100644 mlir/test/python/dialects/transform_extras.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc26335..8013b49dbf9d68 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -311,6 +311,14 @@ declare_mlir_dialect_python_bindings(
dialects/rocdl.py
DIALECT_NAME rocdl)
+declare_mlir_python_sources(
+ MLIRPythonSources.Dialects.transform.extras
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ GEN_ENUM_BINDINGS
+ SOURCES
+ dialects/transform/extras/__init__.py)
+
declare_mlir_python_sources(
MLIRPythonSources.Dialects.quant
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
new file mode 100644
index 00000000000000..0ccbb2e9254fd2
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence, TypeVar
+
+try:
+ from .... import ir
+ from ....dialects import transform
+ from ....dialects.transform import structured
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+ at dataclass
+class Value(abc.ABC):
+ """Wrapper around a transform value handle with methods to chain further transforms."""
+
+ _mlir_value: ir.Value
+ children: list[Value] = field(default_factory=list)
+ parent: Optional[Value] = None
+
+ @property
+ def mlir_value(self) -> ir.Value:
+ return self._mlir_value
+
+
+ at dataclass
+class Param(Value):
+ """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+ at dataclass
+class OpHandle(Value):
+ """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+ def match_ops(
+ self,
+ ops: str
+ | ir.OpView
+ | structured.MatchInterfaceEnum
+ | Sequence[str | ir.OpView],
+ ) -> OpHandle:
+ """Returns a handle to ops that match the given names, types, or interface.
+
+ If only a single type is given, the value wrapped by the resulting
+ handle is populated with the respective type.
+ """
+ # Handle interface.
+ if isinstance(ops, structured.MatchInterfaceEnum) or (
+ isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+ ):
+ if isinstance(ops, str):
+ ops = structured.MatchInterfaceEnum[ops]
+ match_op = structured.MatchOp(
+ transform.AnyOpType.get(),
+ self.mlir_value,
+ interface=ops,
+ )
+
+ # Handle op name(s), either given directly as string or given as op.
+ else:
+ if isinstance(ops, str):
+ op_type = transform.OperationType.get(ops)
+ op_names = [ops]
+ elif isinstance(ops, Sequence):
+ op_type = transform.AnyOpType.get()
+ op_names = [
+ op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+ ]
+ else:
+ op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_names = [ops.OPERATION_NAME]
+ match_op = structured.MatchOp.match_op_names(
+ op_type,
+ self.mlir_value,
+ op_names,
+ )
+
+ handle = OpHandle(match_op.results_, parent=self)
+ self.children.append(handle)
+ return handle
+
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def insert_transform_script(
+ module: ir.Module,
+ script: Callable[[ValueT], None],
+ dump_script: bool = False,
+) -> None:
+ """Inserts the transform script of the schedule into the module.
+
+ Args:
+ module: Existing module into which the script should be inserted.
+ script: The transform script to apply at.
+ dump_script: Whether to dump the script after creation.
+ """
+ # Insert the script into the IR
+ with module.context, ir.Location.unknown(module.context):
+ with ir.InsertionPoint.at_block_begin(module.body):
+ sequence_op = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ (),
+ transform.AnyOpType.get(),
+ )
+ with ir.InsertionPoint(sequence_op.body):
+ script(OpHandle(sequence_op.bodyTarget))
+ transform.YieldOp([])
+
+ if dump_script:
+ print(sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 00000000000000..8c482018868319
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,80 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable, TypeVar
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+
+ValueT = TypeVar("ValueT", bound=Value)
+
+
+def build_transform_script(script: Callable[[ValueT], None]):
+ print("\nTEST:", script.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ insert_transform_script(module, script=script, dump_script=True)
+ module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+ at build_transform_script
+def test_match_ops_single(op: OpHandle):
+ op.match_ops(scf.ForOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+ # CHECK-SAME: in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+ at build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+ op.match_ops("linalg.matmul")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+ at build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+ op.match_ops("LinalgOp")
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+ at build_transform_script
+def test_match_ops_iface(op: OpHandle):
+ op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+ at build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+ op.match_ops([scf.ForOp, scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+ at build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+ op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+ # CHECK: transform.sequence
+ # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+ # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+ # CHECK-SAME: -> !transform.any_op
>From 86fb474f795338fc667f34324debac679b924457 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:21:47 +0100
Subject: [PATCH 2/5] cleaning up comments
---
.../mlir/dialects/transform/extras/__init__.py | 13 ++++---------
1 file changed, 4 insertions(+), 9 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 0ccbb2e9254fd2..4b864cf42a53b6 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -9,7 +9,7 @@
from ....dialects import transform
from ....dialects.transform import structured
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports") from e
@dataclass
@@ -41,8 +41,8 @@ def match_ops(
| structured.MatchInterfaceEnum
| Sequence[str | ir.OpView],
) -> OpHandle:
- """Returns a handle to ops that match the given names, types, or interface.
-
+ """
+ Returns a handle to ops that match the given names, types, or interface.
If only a single type is given, the value wrapped by the resulting
handle is populated with the respective type.
"""
@@ -90,13 +90,8 @@ def insert_transform_script(
script: Callable[[ValueT], None],
dump_script: bool = False,
) -> None:
- """Inserts the transform script of the schedule into the module.
+ """Inserts the transform script of the schedule into the module."""
- Args:
- module: Existing module into which the script should be inserted.
- script: The transform script to apply at.
- dump_script: Whether to dump the script after creation.
- """
# Insert the script into the IR
with module.context, ir.Location.unknown(module.context):
with ir.InsertionPoint.at_block_begin(module.body):
>From ee8f9c2a8429aaa6d0e16fd786d5943ce9e12227 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:24:15 +0100
Subject: [PATCH 3/5] add in-depth documentation to `insert_transform_script`
---
.../dialects/transform/extras/__init__.py | 30 +++++++++++++++----
mlir/test/python/dialects/transform_extras.py | 8 ++---
2 files changed, 27 insertions(+), 11 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 4b864cf42a53b6..8e6788923d796e 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -82,17 +82,35 @@ def match_ops(
return handle
-ValueT = TypeVar("ValueT", bound=Value)
-
-
def insert_transform_script(
module: ir.Module,
- script: Callable[[ValueT], None],
+ script: Callable[[OpHandle], None],
dump_script: bool = False,
) -> None:
- """Inserts the transform script of the schedule into the module."""
+ """
+ Inserts the transform script of the schedule into the module. The script
+ should accept an instance of OpHandle as argument, which will be called with
+ the block arg of the newly created sequence op.
+
+ Example:
+ This python code
+ ```
+ module = ir.Module.create()
+ def test_match_ops_single(module: OpHandle):
+ module.match_ops(scf.ForOp)
+ insert_transform_script(module, script)
+ ```
+ generates the following IR:
+ ```
+ module {
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+ }
+ }
+ ```
+ """
- # Insert the script into the IR
with module.context, ir.Location.unknown(module.context):
with ir.InsertionPoint.at_block_begin(module.body):
sequence_op = transform.SequenceOp(
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 8c482018868319..08d853d8c2bc77 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -1,15 +1,13 @@
# RUN: %PYTHON %s | FileCheck %s
-from typing import Callable, TypeVar
+from typing import Callable
from mlir import ir
from mlir.dialects import scf
from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
+from mlir.dialects.transform.extras import OpHandle, insert_transform_script
-ValueT = TypeVar("ValueT", bound=Value)
-
-def build_transform_script(script: Callable[[ValueT], None]):
+def build_transform_script(script: Callable[[OpHandle], None]):
print("\nTEST:", script.__name__)
with ir.Context(), ir.Location.unknown():
module = ir.Module.create()
>From 7ee6bb0245b46bd6011e17bb25d4635ed2bea589 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Mon, 11 Dec 2023 17:30:09 +0100
Subject: [PATCH 4/5] remove unneeded import
---
mlir/python/mlir/dialects/transform/extras/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8e6788923d796e..9f1f752bd7dba4 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
import abc
from dataclasses import dataclass, field
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Optional, Sequence
try:
from .... import ir
>From fd1ddd1bca4d43c49862274c12981b36e7a1ffe8 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Wed, 13 Dec 2023 15:42:33 +0100
Subject: [PATCH 5/5] Address PR comments: expose `get_static_typeid` for
Python MLIR type subclasses; expose TypeID getters for transform dialect
types in the C API; sequence -> named_sequence; more general
`insert_transform_script`; use `value_caster` automation instead of
composition; clarify documentation; clean up imports; add missing
license header
---
mlir/include/mlir-c/Dialect/Transform.h | 8 +
.../mlir/Bindings/Python/PybindAdaptors.h | 2 +
mlir/lib/Bindings/Python/DialectTransform.cpp | 12 +-
mlir/lib/CAPI/Dialect/Transform.cpp | 18 ++-
mlir/python/CMakeLists.txt | 16 +-
.../dialects/transform/extras/__init__.py | 126 ---------------
.../extras/dialects/transform/__init__.py | 147 ++++++++++++++++++
mlir/test/python/dialects/transform_extras.py | 45 ++++--
8 files changed, 221 insertions(+), 153 deletions(-)
delete mode 100644 mlir/python/mlir/dialects/transform/extras/__init__.py
create mode 100644 mlir/python/mlir/extras/dialects/transform/__init__.py
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 91c99b1f869f22..02c99b59218825 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
@@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
@@ -41,6 +45,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type);
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
@@ -63,6 +69,8 @@ mlirTransformOperationTypeGetOperationName(MlirType type);
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
MlirType type);
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 5e0e56fc00a673..125f9b92937bff 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -495,6 +495,8 @@ class mlir_type_subclass : public pure_subclass {
.attr("replace")(superCls.attr("__name__"), captureTypeName);
});
if (getTypeIDFunction) {
+ def_staticmethod("get_static_typeid",
+ [getTypeIDFunction]() { return getTypeIDFunction(); });
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
getTypeIDFunction())(pybind11::cpp_function(
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index c7764f4e7aecac..6b57e652aa9d8b 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -27,7 +27,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
auto anyOpType =
- mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
+ mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
+ mlirTransformAnyOpTypeGetTypeID);
anyOpType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
@@ -41,7 +42,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
auto anyParamType =
- mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+ mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
+ mlirTransformAnyParamTypeGetTypeID);
anyParamType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
@@ -55,7 +57,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
auto anyValueType =
- mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
+ mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
+ mlirTransformAnyValueTypeGetTypeID);
anyValueType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
@@ -96,7 +99,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
auto paramType =
- mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+ mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
+ mlirTransformParamTypeGetTypeID);
paramType.def_classmethod(
"get",
[](py::object cls, MlirType type, MlirContext ctx) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 3f7f8b8e2113fe..5fd773572bd3c8 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -25,6 +25,10 @@ bool mlirTypeIsATransformAnyOpType(MlirType type) {
return isa<transform::AnyOpType>(unwrap(type));
}
+MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
+ return wrap(transform::AnyOpType::getTypeID());
+}
+
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
@@ -37,6 +41,10 @@ bool mlirTypeIsATransformAnyParamType(MlirType type) {
return isa<transform::AnyParamType>(unwrap(type));
}
+MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
+ return wrap(transform::AnyParamType::getTypeID());
+}
+
MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
return wrap(transform::AnyParamType::get(unwrap(ctx)));
}
@@ -49,6 +57,10 @@ bool mlirTypeIsATransformAnyValueType(MlirType type) {
return isa<transform::AnyValueType>(unwrap(type));
}
+MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
+ return wrap(transform::AnyValueType::getTypeID());
+}
+
MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
return wrap(transform::AnyValueType::get(unwrap(ctx)));
}
@@ -76,13 +88,17 @@ MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
}
//===---------------------------------------------------------------------===//
-// AnyOpType
+// ParamType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformParamType(MlirType type) {
return isa<transform::ParamType>(unwrap(type));
}
+MlirTypeID mlirTransformParamTypeGetTypeID(void) {
+ return wrap(transform::ParamType::getTypeID());
+}
+
MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8013b49dbf9d68..41d91cf6778338 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -166,6 +166,14 @@ declare_mlir_dialect_python_bindings(
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
)
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.extras
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  SOURCES
+    extras/dialects/transform/__init__.py)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -311,14 +319,6 @@ declare_mlir_dialect_python_bindings(
dialects/rocdl.py
DIALECT_NAME rocdl)
-declare_mlir_python_sources(
- MLIRPythonSources.Dialects.transform.extras
- ADD_TO_PARENT MLIRPythonSources.Dialects
- ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
- GEN_ENUM_BINDINGS
- SOURCES
- dialects/transform/extras/__init__.py)
-
declare_mlir_python_sources(
MLIRPythonSources.Dialects.quant
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
deleted file mode 100644
index 9f1f752bd7dba4..00000000000000
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from __future__ import annotations
-
-import abc
-from dataclasses import dataclass, field
-from typing import Callable, Optional, Sequence
-
-try:
- from .... import ir
- from ....dialects import transform
- from ....dialects.transform import structured
-except ImportError as e:
- raise RuntimeError("Error loading imports") from e
-
-
- at dataclass
-class Value(abc.ABC):
- """Wrapper around a transform value handle with methods to chain further transforms."""
-
- _mlir_value: ir.Value
- children: list[Value] = field(default_factory=list)
- parent: Optional[Value] = None
-
- @property
- def mlir_value(self) -> ir.Value:
- return self._mlir_value
-
-
- at dataclass
-class Param(Value):
- """Wrapper around a transform Param with methods to chain further transforms."""
-
-
- at dataclass
-class OpHandle(Value):
- """Wrapper around a transform OpHandle with methods to chain further transforms."""
-
- def match_ops(
- self,
- ops: str
- | ir.OpView
- | structured.MatchInterfaceEnum
- | Sequence[str | ir.OpView],
- ) -> OpHandle:
- """
- Returns a handle to ops that match the given names, types, or interface.
- If only a single type is given, the value wrapped by the resulting
- handle is populated with the respective type.
- """
- # Handle interface.
- if isinstance(ops, structured.MatchInterfaceEnum) or (
- isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
- ):
- if isinstance(ops, str):
- ops = structured.MatchInterfaceEnum[ops]
- match_op = structured.MatchOp(
- transform.AnyOpType.get(),
- self.mlir_value,
- interface=ops,
- )
-
- # Handle op name(s), either given directly as string or given as op.
- else:
- if isinstance(ops, str):
- op_type = transform.OperationType.get(ops)
- op_names = [ops]
- elif isinstance(ops, Sequence):
- op_type = transform.AnyOpType.get()
- op_names = [
- op if isinstance(op, str) else op.OPERATION_NAME for op in ops
- ]
- else:
- op_type = transform.OperationType.get(ops.OPERATION_NAME)
- op_names = [ops.OPERATION_NAME]
- match_op = structured.MatchOp.match_op_names(
- op_type,
- self.mlir_value,
- op_names,
- )
-
- handle = OpHandle(match_op.results_, parent=self)
- self.children.append(handle)
- return handle
-
-
-def insert_transform_script(
- module: ir.Module,
- script: Callable[[OpHandle], None],
- dump_script: bool = False,
-) -> None:
- """
- Inserts the transform script of the schedule into the module. The script
- should accept an instance of OpHandle as argument, which will be called with
- the block arg of the newly created sequence op.
-
- Example:
- This python code
- ```
- module = ir.Module.create()
- def test_match_ops_single(module: OpHandle):
- module.match_ops(scf.ForOp)
- insert_transform_script(module, script)
- ```
- generates the following IR:
- ```
- module {
- transform.sequence failures(propagate) {
- ^bb0(%arg0: !transform.any_op):
- %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
- }
- }
- ```
- """
-
- with module.context, ir.Location.unknown(module.context):
- with ir.InsertionPoint.at_block_begin(module.body):
- sequence_op = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- (),
- transform.AnyOpType.get(),
- )
- with ir.InsertionPoint(sequence_op.body):
- script(OpHandle(sequence_op.bodyTarget))
- transform.YieldOp([])
-
- if dump_script:
- print(sequence_op)
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/extras/dialects/transform/__init__.py
new file mode 100644
index 00000000000000..5870fa4fb1661d
--- /dev/null
+++ b/mlir/python/mlir/extras/dialects/transform/__init__.py
@@ -0,0 +1,147 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from __future__ import annotations
+from typing import Callable, Optional, Sequence
+
+from .... import ir
+from ....dialects import transform
+from ....dialects.transform import structured
+
+
+class Handle(ir.Value):
+ """
+ Base class for wrappers around different types of transform handle with
+ methods to chain further transforms.
+
+ The fields `children` and `parent` are used to capture the relation of
+ handles statically in order to enable further analysis. The payload
+ operation of a child handle is nested into a region of the payload operation
+ of the corresponding parent handle.
+ """
+
+ def __init__(
+ self,
+ v: ir.Value,
+ *,
+ parent: Optional[Handle] = None,
+ children: Optional[Sequence[Handle]] = None,
+ ):
+ super().__init__(v)
+ self.parent = parent
+ self.children = children if children is not None else []
+
+
+ at ir.register_value_caster(transform.AnyOpType.get_static_typeid())
+ at ir.register_value_caster(transform.OperationType.get_static_typeid())
+class OpHandle(Handle):
+ """
+ Wrapper around a transform operation handle with methods to chain further
+ transforms.
+ """
+
+ def __init__(
+ self,
+ v: ir.Value,
+ *,
+ parent: Optional[Handle] = None,
+ children: Optional[Sequence[Handle]] = None,
+ ):
+ super().__init__(v, parent=parent, children=children)
+
+ def match_ops(
+ self,
+ ops: str
+ | ir.OpView
+ | structured.MatchInterfaceEnum
+ | Sequence[str | ir.OpView],
+ ) -> OpHandle:
+ """
+ Returns a handle to ops that match the given names, types, or interface.
+ If only a single type is given, the value wrapped by the resulting
+ handle is populated with the respective type.
+ """
+ # Handle interface.
+ if isinstance(ops, structured.MatchInterfaceEnum) or (
+ isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+ ):
+ if isinstance(ops, str):
+ ops = structured.MatchInterfaceEnum[ops]
+ match_op = structured.MatchOp(
+ transform.AnyOpType.get(),
+ self,
+ interface=ops,
+ )
+
+ # Handle op name(s), either given directly as string or given as op.
+ else:
+ if isinstance(ops, str):
+ op_type = transform.OperationType.get(ops)
+ op_names = [ops]
+ elif isinstance(ops, Sequence):
+ op_type = transform.AnyOpType.get()
+ op_names = [
+ op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+ ]
+ else:
+ op_type = transform.OperationType.get(ops.OPERATION_NAME)
+ op_names = [ops.OPERATION_NAME]
+ match_op = structured.MatchOp.match_op_names(
+ op_type,
+ self,
+ op_names,
+ )
+
+ handle = OpHandle(match_op.results_, parent=self)
+ self.children.append(handle)
+ return handle
+
+
+def insert_transform_script(
+ block_or_insertion_point: ir.Block | ir.InsertionPoint,
+ script: Callable[[OpHandle], None],
+ dump_script: bool = False,
+) -> None:
+ """
+ Inserts the transform script of the schedule into the module. The script
+ should accept an instance of OpHandle as argument, which will be called with
+ the block arg of the newly created named_sequence op.
+
+ Example:
+ This python code
+ ```
+ module = ir.Module.create()
+ def test_match_ops_single(module: OpHandle):
+ module.match_ops(scf.ForOp)
+ insert_transform_script(module.body, script)
+ ```
+ generates the following IR:
+ ```
+ module {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0
+ : (!transform.any_op) -> !transform.op<"scf.for">
+ }
+ }
+ ```
+ """
+ if isinstance(block_or_insertion_point, ir.Block):
+ context = block_or_insertion_point.owner.context
+ insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point)
+ else:
+ context = block_or_insertion_point.block.owner.context
+ insertion_point = block_or_insertion_point
+
+ with context, ir.Location.unknown(context):
+ with insertion_point:
+ named_sequence_op = transform.NamedSequenceOp(
+ "__transform_main", [transform.AnyOpType.get()], []
+ )
+ with ir.InsertionPoint(named_sequence_op.body):
+ script(named_sequence_op.bodyTarget)
+ transform.YieldOp([])
+
+ if dump_script:
+ print(named_sequence_op)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 08d853d8c2bc77..dbfa8a2dc73c41 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -4,23 +4,45 @@
from mlir import ir
from mlir.dialects import scf
from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+from mlir.extras.dialects.transform import OpHandle, insert_transform_script
def build_transform_script(script: Callable[[OpHandle], None]):
print("\nTEST:", script.__name__)
with ir.Context(), ir.Location.unknown():
module = ir.Module.create()
- insert_transform_script(module, script=script, dump_script=True)
+ module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+ insert_transform_script(module.body, script=script, dump_script=True)
module.operation.verify()
+def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
+ print("\nTEST:", script.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+ insert_transform_script(
+ ir.InsertionPoint.at_block_begin(module.body),
+ script=script,
+ dump_script=True,
+ )
+ module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_build_script_at_insertion_point
+ at build_transform_script_at_insertion_point
+def test_build_script_at_insertion_point(op: OpHandle):
+ pass
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: transform.yield
+ # CHECK-NEXT: }
+
+
# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
op.match_ops(scf.ForOp)
- # CHECK: transform.sequence
- # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
# CHECK-SAME: in %[[VAL_0]]
# CHECK-SAME: -> !transform.op<"scf.for">
@@ -30,8 +52,7 @@ def test_match_ops_single(op: OpHandle):
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
op.match_ops("linalg.matmul")
- # CHECK: transform.sequence
- # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
@@ -40,8 +61,7 @@ def test_match_ops_string_name(op: OpHandle):
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
op.match_ops("LinalgOp")
- # CHECK: transform.sequence
- # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
@@ -50,8 +70,7 @@ def test_match_ops_string_iface(op: OpHandle):
@build_transform_script
def test_match_ops_iface(op: OpHandle):
op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
- # CHECK: transform.sequence
- # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
@@ -60,8 +79,7 @@ def test_match_ops_iface(op: OpHandle):
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
op.match_ops([scf.ForOp, scf.ForallOp])
- # CHECK: transform.sequence
- # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
# CHECK-SAME: -> !transform.any_op
@@ -71,8 +89,7 @@ def test_match_ops_multiple(op: OpHandle):
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
- # CHECK: transform.sequence
- # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
# CHECK-SAME: -> !transform.any_op
More information about the Mlir-commits
mailing list