[Mlir-commits] [mlir] [MLIR][transform][python] add sugared python abstractions for transform dialect (PR #75073)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 13 06:49:37 PST 2023


Martin =?utf-8?q?Lücke?= <martin.luecke at ed.ac.uk>,
Martin =?utf-8?q?Lücke?= <martin.luecke at ed.ac.uk>,
Martin =?utf-8?q?Lücke?= <martin.luecke at ed.ac.uk>,
Martin =?utf-8?q?Lücke?= <martin.luecke at ed.ac.uk>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/75073 at github.com>


================
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import abc
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Sequence
+
+try:
+    from .... import ir
+    from ....dialects import transform
+    from ....dialects.transform import structured
+except ImportError as e:
+    raise RuntimeError("Error loading imports") from e
+
+
+ at dataclass
+class Value(abc.ABC):
+    """Wrapper around a transform value handle with methods to chain further transforms."""
+
+    _mlir_value: ir.Value
+    children: list[Value] = field(default_factory=list)
+    parent: Optional[Value] = None
+
+    @property
+    def mlir_value(self) -> ir.Value:
+        return self._mlir_value
+
+
+ at dataclass
+class Param(Value):
+    """Wrapper around a transform Param with methods to chain further transforms."""
+
+
+ at dataclass
+class OpHandle(Value):
+    """Wrapper around a transform OpHandle with methods to chain further transforms."""
+
+    def match_ops(
+        self,
+        ops: str
+        | ir.OpView
+        | structured.MatchInterfaceEnum
+        | Sequence[str | ir.OpView],
+    ) -> OpHandle:
+        """
+        Returns a handle to ops that match the given names, types, or interface.
+        If only a single type is given, the value wrapped by the resulting
+        handle is populated with the respective type.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self.mlir_value,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self.mlir_value,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+def insert_transform_script(
+    module: ir.Module,
+    script: Callable[[OpHandle], None],
+    dump_script: bool = False,
+) -> None:
+    """
+    Inserts the transform script of the schedule into the module. The script
+    should accept an instance of OpHandle as argument, which will be called with
+    the block arg of the newly created sequence op.
+
+    Example:
+    This python code
+    ```
+    module = ir.Module.create()
+    def test_match_ops_single(module: OpHandle):
+        module.match_ops(scf.ForOp)
+    insert_transform_script(module, script)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.sequence failures(propagate) {
+        ^bb0(%arg0: !transform.any_op):
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
+        }
+    }
+    ```
+    """
+
+    with module.context, ir.Location.unknown(module.context):
+        with ir.InsertionPoint.at_block_begin(module.body):
+            sequence_op = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                (),
+                transform.AnyOpType.get(),
+            )
+        with ir.InsertionPoint(sequence_op.body):
+            script(OpHandle(sequence_op.bodyTarget))
----------------
martin-luecke wrote:

I went with your suggestion and exposed the corresponding typeid getters

https://github.com/llvm/llvm-project/pull/75073


More information about the Mlir-commits mailing list