[Mlir-commits] [mlir] [mlir][python] Add normalforms to capture preconditions of transforms (PR #79449)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 26 03:28:11 PST 2024


https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/79449

>From 2d1cbe436bfb9b1991d2360849441c128efd6f01 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Fri, 26 Jan 2024 12:27:55 +0100
Subject: [PATCH] [mlir][python] Add normalforms to capture preconditions for
 transforms

---
 .../dialects/transform/extras/__init__.py     | 83 ++++++++++++++++-
 mlir/test/python/dialects/transform_extras.py | 89 +++++++++++++++++++
 2 files changed, 171 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a36f..8f7d04ddb1f9e54 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Type, TypeVar, Union
 
 from ....extras.meta import region_op
 from .... import ir
@@ -20,6 +20,8 @@
 )
 from .. import structured
 
+# Type variable bound to `Handle`: lets `Handle.normalize` and
+# `NormalForm.apply`/`_impl` declare that they return the same handle type
+# they were given.
+HandleT = TypeVar("HandleT", bound="Handle")
+
 
 class Handle(ir.Value):
     """
@@ -42,6 +44,46 @@ def __init__(
         super().__init__(v)
         self.parent = parent
         self.children = children if children is not None else []
+        self._normalForm = NormalForm
+
+    @property
+    def normalForm(self) -> Type["NormalForm"]:
+        """
+        The normalform of this handle. This is a static property of the handle
+        and indicates a group of previously applied transforms. This can be used
+        by subsequent transforms to statically reason about the structure of the
+        payload operations and whether other enabling transforms could possibly
+        be skipped.
+        Setting this property triggers propagation of the normalform to parent
+        and child handles depending on the specific normalform.
+        """
+        return self._normalForm
+
+    @normalForm.setter
+    def normalForm(self, normalForm: Type["NormalForm"]):
+        self._normalForm = normalForm
+        if self._normalForm.propagate_up:
+            self.propagate_up_normalform(normalForm)
+        if self._normalForm.propagate_down:
+            self.propagate_down_normalform(normalForm)
+
+    def propagate_up_normalform(self, normalForm: Type["NormalForm"]):
+        if self.parent:
+            # We set the parent normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            self.parent._normalForm = normalForm
+            self.parent.propagate_up_normalform(normalForm)
+
+    def propagate_down_normalform(self, normalForm: Type["NormalForm"]):
+        for child in self.children:
+            # We set the child normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            child._normalForm = normalForm
+            child.propagate_down_normalform(normalForm)
+
+    def normalize(self: HandleT, normalForm: Type["NormalForm"]) -> HandleT:
+        return normalForm.apply(self)
+
 
 @ir.register_value_caster(AnyOpType.get_static_typeid())
 @ir.register_value_caster(OperationType.get_static_typeid())
@@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
     return op.param
 
 
+class NormalForm:
+    """
+    Represents the weakest normalform and is the base class for all normalforms.
+    A normalform is defined as a sequence of transforms to be applied to a
+    handle to reach this normalform.
+
+    `propagate_up`: Propagate this normalform up to parent handles.
+    `propagate_down`: Propagate this normalform down to all child handles
+    """
+
+    propagate_up: bool = True
+    propagate_down: bool = True
+
+    def __init__(self):
+        raise TypeError(
+            "NormalForm cannot be instantiated directly. Use Type[NormalForm]"
+            "instead."
+        )
+
+    @classmethod
+    def _impl(cls, handle: HandleT) -> HandleT:
+        """
+        Defines the transforms required to reach this normalform.
+        A normalform may apply arbitrary transforms and thus possibly
+        invalidate `handle`.
+        """
+        return handle
+
+    @classmethod
+    def apply(cls, handle: HandleT) -> HandleT:
+        """Apply transforms to a handle to bring it into this normalform."""
+        new_handle = cls._impl(handle)
+        new_handle.children.extend(handle.children)
+        new_handle.parent = handle.parent
+        # Setting this property propagates the normalform accordingly
+        new_handle.normalForm = cls
+        return new_handle
+
+
 def insert_transform_script(
     block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
     script: Callable[[OpHandle], None],
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index ea47f170cb63212..b329609f795b52d 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -19,6 +19,7 @@
     insert_transform_script,
     sequence,
     apply_patterns,
+    NormalForm,
 )
 from mlir.extras import types as T
 
@@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]
         module.operation.verify()
 
 
+def run(f: Callable[[], None]):
+    print("\nTEST:", f.__name__)
+    with ir.Context(), ir.Location.unknown():
+        f()
+
+
 # CHECK-LABEL: TEST: test_build_script_at_insertion_point
 @build_transform_script_at_insertion_point
 def test_build_script_at_insertion_point(op: OpHandle):
@@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle):
     # CHECK-SAME:     -> !transform.any_op
 
 
+# CHECK-LABEL: TEST: test_normalform_base
+ at build_transform_script
+def test_normalform_base(op: OpHandle):
+    # Normalform is the weakest normalform so op should already be in that form.
+    # Normalization to Normalform should be a no-op.
+    assert op._normalForm is NormalForm
+    op.normalize(NormalForm)
+    assert op._normalForm is NormalForm
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.yield
+
+
+class DummyNormalform(NormalForm):
+    propagate_up: bool = True
+    propagate_down: bool = True
+
+    @classmethod
+    def _impl(cls, handle: OpHandle) -> OpHandle:
+        return handle.print("dummy normalization")
+
+
+# CHECK-LABEL: test_normalform_no_instantiation
+ at run
+def test_normalform_no_instantiation():
+    try:
+        DummyNormalform()
+    except TypeError as e:
+        print(e)
+    else:
+        print("Exception not produced")
+
+    # CHECK: NormalForm cannot be instantiated directly
+
+
+# CHECK-LABEL: TEST: test_normalform_dummyform
+ at build_transform_script
+def test_normalform_dummyform(op: OpHandle):
+    op.normalize(DummyNormalform)
+    assert op._normalForm is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up
+ at build_transform_script
+def test_normalform_propagate_up(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op")
+    nested_handle.normalize(DummyNormalform)
+    assert nested_handle._normalForm is DummyNormalform
+    assert op._normalForm is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+    # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_down
+ at build_transform_script
+def test_normalform_propagate_down(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op")
+    op.normalize(DummyNormalform)
+    assert nested_handle._normalForm is DummyNormalform
+    assert op._normalForm is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down
+ at build_transform_script
+def test_normalform_propagate_up_and_down(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op1")
+    nested_nested_handle = nested_handle.match_ops("dummy.op2")
+    nested_handle.normalize(DummyNormalform)
+    assert nested_handle._normalForm is DummyNormalform
+    assert op._normalForm is DummyNormalform
+    assert nested_nested_handle._normalForm is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]}
+    # CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]}
+    # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
 # CHECK-LABEL: TEST: test_print_message
 @build_transform_script
 def test_print_message(op: OpHandle):



More information about the Mlir-commits mailing list