[Mlir-commits] [mlir] [MLIR][transform][python] Introduce abstractions for handles to values and parameters (PR #77305)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 8 05:49:04 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (martin-luecke)

<details>
<summary>Changes</summary>

In addition to the existing `OpHandle`, which provides an abstraction for emitting transform ops that target operations, this introduces a similar concept for _values_ and _parameters_ in the form of `ValueHandle` and `ParamHandle`.

New core transform abstractions:
- `constant_param`
- `OpHandle.get_result`
- `OpHandle.print`
- `ValueHandle.get_defining_op`

---
Full diff: https://github.com/llvm/llvm-project/pull/77305.diff


2 Files Affected:

- (modified) mlir/python/mlir/dialects/transform/extras/__init__.py (+85) 
- (modified) mlir/test/python/dialects/transform_extras.py (+71) 


``````````diff
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index e4d47e9064f2c8..ba51c400fe2cb2 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -6,9 +6,13 @@
 
 from ....extras.meta import region_op
 from .... import ir
+from ... import transform
 from .. import (
     AnyOpType,
+    AnyParamType,
+    AnyValueType,
     OperationType,
+    ParamType,
     NamedSequenceOp,
     YieldOp,
     SequenceOp,
@@ -57,6 +61,19 @@ def __init__(
     ):
         super().__init__(v, parent=parent, children=children)
 
+    def get_result(self, idx: int = 0) -> "ValueHandle":
+        """
+        Emits a `transform.GetResultOp` selecting result number `idx`.
+        Returns a handle, typed `!transform.any_value`, to that result of the
+        payload operation associated with this handle.
+        """
+        get_result_op = transform.GetResultOp(
+            AnyValueType.get(),  # type of the produced value handle
+            self,  # op handle whose payload result is taken
+            idx,  # result position; defaults to the first result
+        )
+        return get_result_op.result
+
     def match_ops(
         self,
         ops: Union[
@@ -107,6 +124,74 @@ def match_ops(
         self.children.append(handle)
         return handle
 
+    def print(self, name: Optional[str] = None) -> "OpHandle":
+        """
+        Emits a `transform.PrintOp` targeting this handle, with `name` (if
+        given) as the message. Returns `self` so calls can be chained.
+        """
+        transform.PrintOp(target=self, name=name)
+        return self
+
+
+@ir.register_value_caster(AnyParamType.get_static_typeid())
+@ir.register_value_caster(ParamType.get_static_typeid())
+class ParamHandle(Handle):
+    """Wrapper around a `!transform.any_param` / `!transform.param` handle."""
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v, parent=parent, children=children)
+
+
+@ir.register_value_caster(AnyValueType.get_static_typeid())
+class ValueHandle(Handle):
+    """
+    Wrapper around a `!transform.any_value` handle with methods to chain
+    further transforms.
+    """
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v, parent=parent, children=children)
+
+    def get_defining_op(self) -> OpHandle:
+        """
+        Emits a `transform.GetDefiningOp`.
+        Returns an op handle to the defining op of the wrapped value.
+        """
+        get_defining_op = transform.GetDefiningOp(
+            AnyOpType.get(),  # result type: !transform.any_op
+            self,  # value handle whose defining op is taken
+        )
+        return get_defining_op.result
+
+
+def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
+    """
+    Emits a `transform.ParamConstantOp` and returns a handle to the new param.
+    A Python `int` is first wrapped as a signless 64-bit `IntegerAttr`. The
+    param type is `!transform.param<T>` when the attribute's type `T` is an
+    integer type, and `!transform.any_param` otherwise.
+    """
+    if isinstance(value, int):  # NOTE(review): bool also matches -> i64 attr
+        value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
+    if isinstance(value.type, ir.IntegerType):
+        param_type = ParamType.get(value.type)
+    else:
+        param_type = AnyParamType.get()
+    op = transform.ParamConstantOp(param_type, value)
+    return op.param
+
 
 def insert_transform_script(
     block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 358f8c32f75c75..ea47f170cb6321 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -14,6 +14,7 @@
 from mlir.dialects.transform.structured import structured_match
 from mlir.dialects.transform.loop import loop_unroll
 from mlir.dialects.transform.extras import (
+    constant_param,
     OpHandle,
     insert_transform_script,
     sequence,
@@ -63,6 +64,60 @@ def test_build_script_at_insertion_point(op: OpHandle):
     # CHECK-NEXT: }
 
 
+# CHECK-LABEL: TEST: test_constant_param_int
+@build_transform_script
+def test_constant_param_int(_: OpHandle):
+    constant_param(ir.IntegerAttr.get(T.i32(), 42))  # i32 attr -> param<i32>
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
+    # CHECK-SAME:   !transform.param<i32>
+
+
+# CHECK-LABEL: TEST: test_constant_param_py_int
+@build_transform_script
+def test_constant_param_py_int(_: OpHandle):
+    constant_param(42)  # plain Python int is wrapped as an i64 attr
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
+    # CHECK-SAME:   !transform.param<i64>
+
+
+# CHECK-LABEL: TEST: test_constant_param_symbol_attr
+@build_transform_script
+def test_constant_param_symbol_attr(_: OpHandle):
+    constant_param(ir.SymbolRefAttr.get(["symbol"]))  # non-integer attr
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
+    # CHECK-SAME:   !transform.any_param
+
+
+# CHECK-LABEL: TEST: test_constant_param_type
+@build_transform_script
+def test_constant_param_type(_: OpHandle):
+    constant_param(ir.TypeAttr.get(T.i32()))  # type attr, not an integer attr
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
+    # CHECK-SAME:   !transform.any_param
+
+
+# CHECK-LABEL: TEST: test_get_defining_op
+@build_transform_script
+def test_get_defining_op(op: OpHandle):
+    op.get_result().get_defining_op()  # op handle -> value handle -> op handle
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
+    # CHECK-SAME:   !transform.any_value
+    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]
+
+
+# CHECK-LABEL: TEST: test_get_result
+@build_transform_script
+def test_get_result(op: OpHandle):
+    op.get_result()  # default idx selects result 0
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
+
+
 # CHECK-LABEL: TEST: test_match_ops_single
 @build_transform_script
 def test_match_ops_single(op: OpHandle):
@@ -120,6 +175,22 @@ def test_match_ops_mixed(op: OpHandle):
     # CHECK-SAME:     -> !transform.any_op
 
 
+# CHECK-LABEL: TEST: test_print_message
+@build_transform_script
+def test_print_message(op: OpHandle):
+    op.print("message")  # name is emitted as an op attribute
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}
+
+
+# CHECK-LABEL: TEST: test_print_plain
+@build_transform_script
+def test_print_plain(op: OpHandle):
+    op.print()  # no name given
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.print %[[VAL_0]]
+
+
 # CHECK-LABEL: TEST: test_sequence_region
 @construct_and_print_in_module
 def test_sequence_region():

``````````

</details>


https://github.com/llvm/llvm-project/pull/77305


More information about the Mlir-commits mailing list