[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)

Maksim Levental llvmlistbot at llvm.org
Mon Dec 11 09:51:21 PST 2023


Martin Lücke <martin.luecke at ed.ac.uk>
In-Reply-To: <llvm.org/llvm/llvm-project/pull/75073 at github.com>


================
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence
+
+try:
+    from .... import ir
+    from ....dialects import transform
+    from ....dialects.transform import structured
+except ImportError as e:
+    raise RuntimeError("Error loading imports") from e
+
+
+@dataclass
+class Value(abc.ABC):
+    """Wrapper around a transform value handle with methods to chain further transforms."""
+
+    _mlir_value: ir.Value
+    children: list[Value] = field(default_factory=list)
+    parent: Optional[Value] = None
+
+    @property
+    def mlir_value(self) -> ir.Value:
+        return self._mlir_value
+
+
+@dataclass
+class Param(Value):
+    """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+@dataclass
+class OpHandle(Value):
+    """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+    def match_ops(
+        self,
+        ops: str
+        | ir.OpView
+        | structured.MatchInterfaceEnum
+        | Sequence[str | ir.OpView],
+    ) -> OpHandle:
+        """
+        Returns a handle to ops that match the given names, types, or interface.
+        If only a single type is given, the value wrapped by the resulting
+        handle is populated with the respective type.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self.mlir_value,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self.mlir_value,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+def insert_transform_script(
+    module: ir.Module,
+    script: Callable[[OpHandle], None],
+    dump_script: bool = False,
+) -> None:
+    """
+    Inserts the transform script of the schedule into the module. The script
+    should accept an instance of OpHandle as its argument; it will be called
+    with an OpHandle wrapping the block argument of the newly created sequence op.
+
+    Example:
+    This Python code
+    ```
+    module = ir.Module.create()
+    def test_match_ops_single(module: OpHandle):
+        module.match_ops(scf.ForOp)
+    insert_transform_script(module, test_match_ops_single)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.sequence failures(propagate) {
+        ^bb0(%arg0: !transform.any_op):
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+        }
+    }
+    ```
+    """
+
+    with module.context, ir.Location.unknown(module.context):
+        with ir.InsertionPoint.at_block_begin(module.body):
+            sequence_op = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                (),
+                transform.AnyOpType.get(),
+            )
+        with ir.InsertionPoint(sequence_op.body):
+            script(OpHandle(sequence_op.bodyTarget))
----------------
makslevental wrote:

You don't need to do this: you can [`register_value_caster(OpHandle)`](https://github.com/llvm/llvm-project/blob/7c850867b9ef4427375da6d83c34d0b9c944fcb8/mlir/test/python/ir/value.py#L372) and it'll automagically be done for you _everywhere_ `ir.Value`s are returned from the bindings. Doing this has the added benefit that you don't need the proxy/shadow values: your "cast" values will already be subclasses of `ir.Value` (i.e., you'll be able to ditch the indirection through `self.mlir_value`).
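The caster pattern described above can be illustrated with a small pure-Python mock. Everything in it (`Value`, `_CASTERS`, `register_caster`, `from_bindings`, the string type keys) is a hypothetical stand-in, not the MLIR bindings' API; it only shows why a registered caster lets `OpHandle` *be* a value instead of wrapping one and forwarding through `self.mlir_value`.

```python
from typing import Callable, Dict, List


class Value:
    """Hypothetical stand-in for ir.Value (the bindings' plain value class)."""

    def __init__(self, type_key: str, text: str):
        self.type_key = type_key  # stands in for the value's type identity
        self.text = text


# Hypothetical registry: type key -> callable that upgrades a plain Value.
_CASTERS: Dict[str, Callable[[Value], Value]] = {}


def register_caster(type_key: str):
    """Decorator: register `caster` for every value whose type matches `type_key`."""

    def decorator(caster):
        _CASTERS[type_key] = caster
        return caster

    return decorator


def from_bindings(value: Value) -> Value:
    """Models every point where the bindings hand a value back to Python."""
    caster = _CASTERS.get(value.type_key)
    return caster(value) if caster is not None else value


@register_caster("!transform.op")
class OpHandle(Value):
    """Is-a Value, so no `self.mlir_value` forwarding is needed."""

    def __init__(self, value: Value):
        super().__init__(value.type_key, value.text)
        self.children: List["OpHandle"] = []

    def match_ops(self, op_name: str) -> "OpHandle":
        # Fake the result a real MatchOp would produce; the caster upgrades it.
        result = from_bindings(Value("!transform.op", f"match {op_name} in {self.text}"))
        self.children.append(result)
        return result
```

For example, `from_bindings(Value("!transform.op", "%arg0")).match_ops("scf.for")` yields an `OpHandle` that is also a `Value`, while a value of an unregistered type comes back unchanged. In the real bindings the registry is keyed by a `TypeID` rather than a string.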

Having quickly refreshed my memory of how the transform dialect bindings work, I think the only thing missing there is exposure of `TypeID` for some of the transform types (it looks like [`OperationTypeGetTypeID`](https://github.com/llvm/llvm-project/blob/97f9f1a08ab1f5f91282cf95d13f306d03dc0888/mlir/lib/CAPI/Dialect/Transform.cpp#L64) is exposed, but none of the others are).

https://github.com/llvm/llvm-project/pull/75073

