[Mlir-commits] [mlir] [MLIR][transform][python] Introduce abstractions for handles to values and parameters (PR #77305)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 8 05:48:33 PST 2024
https://github.com/martin-luecke created https://github.com/llvm/llvm-project/pull/77305
In addition to the existing `OpHandle`, which provides an abstraction for emitting transform ops that target operations, this introduces a similar concept for _values_ and _parameters_ in the form of `ValueHandle` and `ParamHandle`.
New core transform abstractions:
- `constant_param`
- `OpHandle.get_result`
- `OpHandle.print`
- `ValueHandle.get_defining_op`
From 2a48c4d76ac7e071d537783f9c781388e4bee0d4 Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Mon, 8 Jan 2024 14:47:17 +0100
Subject: [PATCH] Introduce ValueHandle and ParamHandle
---
.../dialects/transform/extras/__init__.py | 85 +++++++++++++++++++
mlir/test/python/dialects/transform_extras.py | 71 ++++++++++++++++
2 files changed, 156 insertions(+)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index e4d47e9064f2c8..ba51c400fe2cb2 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -6,9 +6,13 @@
from ....extras.meta import region_op
from .... import ir
+from ... import transform
from .. import (
AnyOpType,
+ AnyParamType,
+ AnyValueType,
OperationType,
+ ParamType,
NamedSequenceOp,
YieldOp,
SequenceOp,
@@ -57,6 +61,19 @@ def __init__(
):
super().__init__(v, parent=parent, children=children)
+    def get_result(self, idx: int = 0) -> "ValueHandle":
+        """
+        Emits a `transform.GetResultOp` producing a `!transform.any_value`.
+        Returns a handle to the result at position `idx` of the payload
+        operation associated with this handle.
+        """
+        get_result_op = transform.GetResultOp(
+            AnyValueType.get(),
+            self,
+            idx,
+        )
+        return get_result_op.result
+
def match_ops(
self,
ops: Union[
@@ -107,6 +124,74 @@ def match_ops(
self.children.append(handle)
return handle
+    def print(self, name: Optional[str] = None) -> "OpHandle":
+        """
+        Emits a `transform.PrintOp` for this handle, labelled with `name` if set.
+        Returns this same handle so that further transforms can be chained.
+        """
+        transform.PrintOp(target=self, name=name)
+        return self
+
+
+@ir.register_value_caster(AnyParamType.get_static_typeid())
+@ir.register_value_caster(ParamType.get_static_typeid())
+class ParamHandle(Handle):
+    """Wrapper around a `!transform.param` / `!transform.any_param` handle."""
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v, parent=parent, children=children)
+
+
+@ir.register_value_caster(AnyValueType.get_static_typeid())
+class ValueHandle(Handle):
+    """
+    Wrapper around a `!transform.any_value` handle with methods to chain
+    further transforms.
+    """
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v, parent=parent, children=children)
+
+    def get_defining_op(self) -> OpHandle:
+        """
+        Emits a `transform.GetDefiningOp` (`transform.get_defining_op`).
+        Returns an `!transform.any_op` handle to the op defining this value.
+        """
+        get_defining_op = transform.GetDefiningOp(
+            AnyOpType.get(),
+            self,
+        )
+        return get_defining_op.result
+
+
+def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
+    """
+    Emits a `transform.ParamConstantOp` and returns a handle to the new param.
+    Python ints (bools included) become signless i64 `IntegerAttr`s. The param
+    type is `!transform.param<T>` if the attribute's type `T` is an
+    `IntegerType`, otherwise `!transform.any_param`.
+    """
+    if isinstance(value, int):
+        value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
+    if isinstance(value.type, ir.IntegerType):
+        param_type = ParamType.get(value.type)
+    else:
+        param_type = AnyParamType.get()
+    op = transform.ParamConstantOp(param_type, value)
+    return op.param
+
def insert_transform_script(
block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 358f8c32f75c75..ea47f170cb6321 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -14,6 +14,7 @@
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import (
+ constant_param,
OpHandle,
insert_transform_script,
sequence,
@@ -63,6 +64,60 @@ def test_build_script_at_insertion_point(op: OpHandle):
# CHECK-NEXT: }
+# CHECK-LABEL: TEST: test_constant_param_int
+@build_transform_script
+def test_constant_param_int(_: OpHandle):
+    constant_param(ir.IntegerAttr.get(T.i32(), 42))
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
+    # CHECK-SAME: !transform.param<i32>
+
+
+# CHECK-LABEL: TEST: test_constant_param_py_int
+@build_transform_script
+def test_constant_param_py_int(_: OpHandle):
+    constant_param(42)
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
+    # CHECK-SAME: !transform.param<i64>
+
+
+# CHECK-LABEL: TEST: test_constant_param_symbol_attr
+@build_transform_script
+def test_constant_param_symbol_attr(_: OpHandle):
+    constant_param(ir.SymbolRefAttr.get(["symbol"]))
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
+    # CHECK-SAME: !transform.any_param
+
+
+# CHECK-LABEL: TEST: test_constant_param_type
+@build_transform_script
+def test_constant_param_type(_: OpHandle):
+    constant_param(ir.TypeAttr.get(T.i32()))
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
+    # CHECK-SAME: !transform.any_param
+
+
+# CHECK-LABEL: TEST: test_get_defining_op
+@build_transform_script
+def test_get_defining_op(op: OpHandle):
+    op.get_result().get_defining_op()
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
+    # CHECK-SAME: !transform.any_value
+    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]
+
+
+# CHECK-LABEL: TEST: test_get_result
+@build_transform_script
+def test_get_result(op: OpHandle):
+    op.get_result()
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
+
+
# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
@@ -120,6 +175,22 @@ def test_match_ops_mixed(op: OpHandle):
# CHECK-SAME: -> !transform.any_op
+# CHECK-LABEL: TEST: test_print_message
+@build_transform_script
+def test_print_message(op: OpHandle):
+    op.print("message")
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}
+
+
+# CHECK-LABEL: TEST: test_print_plain
+@build_transform_script
+def test_print_plain(op: OpHandle):
+    op.print()
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.print %[[VAL_0]]
+
+
# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
More information about the Mlir-commits
mailing list