[Mlir-commits] [mlir] [mlir][python] Add normalforms to capture preconditions of transforms (PR #79449)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 25 06:18:12 PST 2024


https://github.com/martin-luecke created https://github.com/llvm/llvm-project/pull/79449

This adds the concept of a _Normalform_ to the transform dialect Python extras.

A normalform is defined by the sequence of transforms required to bring a handle into that normalform. The term and the design are inspired by a similar concept from the term rewriting community.

Each transform handle adheres to a normalform. Initially this is simply `Normalform`, the weakest normalform, from which all other normalforms inherit.
A normalform may imply other normalforms; this is currently modeled using inheritance, i.e., stronger normalforms inherit from weaker normalforms.
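Because normalforms are plain classes, checking whether a stronger normalform implies a weaker one reduces to `issubclass`. The sketch below does not import MLIR: `Normalform` is a stand-in for the class added in this patch, while `CanonicalForm`, `PerfectForNestForm` and the `satisfies` helper are hypothetical names used only for illustration.

```python
class Normalform:
    """Stand-in for the weakest normalform; only ever used as a type."""

    def __init__(self):
        raise TypeError("Normalform cannot be instantiated directly.")


# Hypothetical normalforms: each subclass is stronger than its base.
class CanonicalForm(Normalform):
    pass


class PerfectForNestForm(CanonicalForm):
    pass


def satisfies(current: type, required: type) -> bool:
    """Return True if a handle in `current` is also in `required`.

    Stronger normalforms inherit from weaker ones, so implication is
    exactly the subclass relation.
    """
    return issubclass(current, required)
```

With this encoding, `satisfies(PerfectForNestForm, CanonicalForm)` holds, the converse does not, and every normalform satisfies the base `Normalform`.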

Note that the normalform of a handle is currently a static property: it records which normalizations were applied to the handle, but carries no precise information about the actual payload operations.

Normalizing a handle to a normalform might trigger propagation of this form to parent and child handles, depending on the specific normalform.
An example of a (hypothetical) normalform is `PerfectForNestForm`, which aims to achieve perfectly nested loops in the IR, where possible, by applying canonicalization and loop-invariant code motion transforms. It could look as follows:
```python
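from mlir.dialects import transform
from mlir.dialects.transform import loop, structured
from mlir.dialects.transform.extras import Normalform, OpHandle


class PerfectForNestForm(Normalform):
  propagate_up = False
  propagate_down = True

  @classmethod
  def _impl(cls, handle: OpHandle) -> OpHandle:
    # Assumed OpHandle helpers, not part of this patch:
    # apply_patterns(), apply_licm() and apply_cse().
    with handle.apply_patterns():
      structured.ApplyTilingCanonicalizationPatternsOp()
      loop.ApplyForLoopCanonicalizationPatternsOp()
      transform.ApplyCanonicalizationPatternsOp()
    handle.apply_licm()
    handle.apply_cse()
    # Normalform.apply continues with the returned handle.
    return handle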
```

This normalform is propagated only to child handles: operations nested at a deeper level are affected by the same transforms and will therefore also consist of perfectly nested loops, where possible. It is not propagated to parent handles, because the enclosing operations are not affected by this specific normalization.
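The propagation rules can be exercised without MLIR. The following toy model mirrors the `normalform` setter and the `propagate_up_normalform` / `propagate_down_normalform` logic from the patch below, under stated assumptions: `Handle` is a plain Python class rather than an `ir.Value`, handles are linked to their parent in the constructor (in the patch, `parent`/`children` are populated elsewhere), and `PerfectForNestForm` is a stand-in with the flags from the example above. One deliberate difference: `apply` re-attaches parent and children only when `_impl` returns a different handle object. When `_impl` returns the handle it was given (as the base `Normalform._impl` does), the patch's `new_handle.children.extend(handle.children)` extends a list with itself and duplicates every child.

```python
from typing import List, Optional, Type


class Normalform:
    """Stand-in for the base normalform: propagates up and down."""

    propagate_up: bool = True
    propagate_down: bool = True

    @classmethod
    def _impl(cls, handle: "Handle") -> "Handle":
        # A real normalform would emit transform ops here.
        return handle

    @classmethod
    def apply(cls, handle: "Handle") -> "Handle":
        new_handle = cls._impl(handle)
        if new_handle is not handle:
            # Re-attach relations only for a genuinely new handle;
            # extending a list with itself would duplicate every child.
            new_handle.children.extend(handle.children)
            new_handle.parent = handle.parent
        new_handle.normalform = cls  # the setter triggers propagation
        return new_handle


class Handle:
    """Toy handle with parent/child links; not an MLIR value."""

    def __init__(self, parent: Optional["Handle"] = None):
        self.parent = parent
        self.children: List["Handle"] = []
        self._normalform: Type[Normalform] = Normalform
        if parent is not None:
            parent.children.append(self)

    @property
    def normalform(self) -> Type[Normalform]:
        return self._normalform

    @normalform.setter
    def normalform(self, form: Type[Normalform]) -> None:
        self._normalform = form
        if form.propagate_up:
            self._propagate_up(form)
        if form.propagate_down:
            self._propagate_down(form)

    def _propagate_up(self, form: Type[Normalform]) -> None:
        # Assign _normalform directly so the setter does not recurse.
        node = self.parent
        while node is not None:
            node._normalform = form
            node = node.parent

    def _propagate_down(self, form: Type[Normalform]) -> None:
        for child in self.children:
            child._normalform = form
            child._propagate_down(form)

    def normalize(self, form: Type[Normalform]) -> "Handle":
        return form.apply(self)


class PerfectForNestForm(Normalform):
    # Hypothetical stand-in mirroring the example above: children only.
    propagate_up = False
    propagate_down = True


root = Handle()
loop_nest = Handle(parent=root)
inner = Handle(parent=loop_nest)
loop_nest.normalize(PerfectForNestForm)
print(root.normalform.__name__, loop_nest.normalform.__name__, inner.normalform.__name__)
# prints: Normalform PerfectForNestForm PerfectForNestForm
```

Normalizing the middle handle marks the nested handle as well, while the enclosing handle keeps the weakest normalform.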

With the current design, normalforms are never instantiated; only normalform types (the classes themselves) are used, e.g.:
```python
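from typing import Type
from mlir import ir

class Handle(ir.Value):
  @property
  def normalform(self) -> Type["Normalform"]:
    # Holds a normalform class, never an instance.
    return self._normalform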
```
This could also be modeled as an `Enum`, but that would make it harder to model a hierarchy of normalforms.
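For comparison, here is a sketch of what the `Enum` alternative might look like. The `Form` members and the `IMPLIES` table are hypothetical. Since enum members cannot inherit from one another, the implication order has to be stored and traversed by hand, whereas with classes it is carried by the base-class list.

```python
from enum import Enum, auto
from typing import Dict, Set


class Form(Enum):
    # Hypothetical flat encoding of three normalforms.
    BASE = auto()
    CANONICAL = auto()
    PERFECT_FOR_NEST = auto()


# Direct implications must be maintained separately from the members.
IMPLIES: Dict[Form, Set[Form]] = {
    Form.BASE: set(),
    Form.CANONICAL: {Form.BASE},
    Form.PERFECT_FOR_NEST: {Form.CANONICAL},
}


def satisfies(current: Form, required: Form) -> bool:
    """Follow IMPLIES transitively, mimicking issubclass on classes."""
    stack = [current]
    seen = set()
    while stack:
        form = stack.pop()
        if form is required:
            return True
        if form in seen:
            continue
        seen.add(form)
        stack.extend(IMPLIES[form])
    return False
```

Per-form attributes such as `propagate_up` and `propagate_down` would likewise need separate lookup tables, whereas subclasses simply override class attributes.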

From 4d9bb51c9993de82e5f68d648ca98916add9e72c Mon Sep 17 00:00:00 2001
From: Martin Lücke <martin.luecke at ed.ac.uk>
Date: Thu, 25 Jan 2024 13:34:13 +0100
Subject: [PATCH] [mlir][python] Add normalforms to capture preconditions for
 transforms

---
 .../dialects/transform/extras/__init__.py     | 83 ++++++++++++++++-
 mlir/test/python/dialects/transform_extras.py | 89 +++++++++++++++++++
 2 files changed, 171 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a36f..ccfe4a6babb6fde 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Type, TypeVar, Union
 
 from ....extras.meta import region_op
 from .... import ir
@@ -20,6 +20,8 @@
 )
 from .. import structured
 
+HandleT = TypeVar("HandleT", bound="Handle")
+
 
 class Handle(ir.Value):
     """
@@ -42,6 +44,46 @@ def __init__(
         super().__init__(v)
         self.parent = parent
         self.children = children if children is not None else []
+        self._normalform = Normalform
+
+    @property
+    def normalform(self) -> Type["Normalform"]:
+        """
+        The normalform of this handle. This is a static property of the handle
+        and indicates a group of previously applied transforms. This can be used
+        by subsequent transforms to statically reason about the structure of the
+        payload operations and whether other enabling transforms could possibly
+        be skipped.
+        Setting this property triggers propagation of the normalform to parent
+        and child handles depending on the specific normalform.
+        """
+        return self._normalform
+
+    @normalform.setter
+    def normalform(self, normalform: Type["Normalform"]):
+        self._normalform = normalform
+        if self._normalform.propagate_up:
+            self.propagate_up_normalform(normalform)
+        if self._normalform.propagate_down:
+            self.propagate_down_normalform(normalform)
+
+    def propagate_up_normalform(self, normalform: Type["Normalform"]):
+        if self.parent:
+            # We set the parent normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            self.parent._normalform = normalform
+            self.parent.propagate_up_normalform(normalform)
+
+    def propagate_down_normalform(self, normalform: Type["Normalform"]):
+        for child in self.children:
+            # We set the child normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            child._normalform = normalform
+            child.propagate_down_normalform(normalform)
+
+    def normalize(self: "HandleT", normalform: Type["Normalform"]) -> "HandleT":
+        return normalform.apply(self)
+
 
 @ir.register_value_caster(AnyOpType.get_static_typeid())
 @ir.register_value_caster(OperationType.get_static_typeid())
@@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
     return op.param
 
 
+class Normalform:
+    """
+    Represents the weakest normalform and is the base class for all normalforms.
+    A normalform is defined as a sequence of transforms to be applied to a
+    handle to reach this normalform.
+
+    `propagate_up`: Propagate this normalform up to parent handles.
+    `propagate_down`: Propagate this normalform down to all child handles
+    """
+
+    propagate_up: bool = True
+    propagate_down: bool = True
+
+    def __init__(self):
+        raise TypeError(
+            "Normalform cannot be instantiated directly. Use Type[Normalform] "
+            "instead."
+        )
+
+    @classmethod
+    def _impl(cls, handle: "HandleT") -> "HandleT":
+        """
+        Defines the transforms required to reach this normalform.
+        A normalform may apply arbitrary transforms and thus possibly
+        invalidate `handle`.
+        """
+        return handle
+
+    @classmethod
+    def apply(cls, handle: "HandleT") -> "HandleT":
+        """Apply transforms to a handle to bring it into this normalform."""
+        new_handle = cls._impl(handle)
+        new_handle.children.extend(handle.children)
+        new_handle.parent = handle.parent
+        # Setting this property propagates the normalform accordingly
+        new_handle.normalform = cls
+        return new_handle
+
+
 def insert_transform_script(
     block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
     script: Callable[[OpHandle], None],
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index ea47f170cb63212..2ca7cfc5ead04f3 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -19,6 +19,7 @@
     insert_transform_script,
     sequence,
     apply_patterns,
+    Normalform,
 )
 from mlir.extras import types as T
 
@@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]
         module.operation.verify()
 
 
+def run(f: Callable[[], None]):
+    print("\nTEST:", f.__name__)
+    with ir.Context(), ir.Location.unknown():
+        f()
+
+
 # CHECK-LABEL: TEST: test_build_script_at_insertion_point
 @build_transform_script_at_insertion_point
 def test_build_script_at_insertion_point(op: OpHandle):
@@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle):
     # CHECK-SAME:     -> !transform.any_op
 
 
+# CHECK-LABEL: TEST: test_normalform_base
+@build_transform_script
+def test_normalform_base(op: OpHandle):
+    # Normalform is the weakest normalform so op should already be in that form.
+    # Normalization to Normalform should be a no-op.
+    assert op._normalform is Normalform
+    op.normalize(Normalform)
+    assert op._normalform is Normalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.yield
+
+
+class DummyNormalform(Normalform):
+    propagate_up: bool = True
+    propagate_down: bool = True
+
+    @classmethod
+    def _impl(cls, handle: OpHandle) -> OpHandle:
+        return handle.print("dummy normalization")
+
+
+# CHECK-LABEL: test_normalform_no_instantiation
+@run
+def test_normalform_no_instantiation():
+    try:
+        DummyNormalform()
+    except TypeError as e:
+        print(e)
+    else:
+        print("Exception not produced")
+
+    # CHECK: Normalform cannot be instantiated directly
+
+
+# CHECK-LABEL: TEST: test_normalform_dummyform
+@build_transform_script
+def test_normalform_dummyform(op: OpHandle):
+    op.normalize(DummyNormalform)
+    assert op._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up
+@build_transform_script
+def test_normalform_propagate_up(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op")
+    nested_handle.normalize(DummyNormalform)
+    assert nested_handle._normalform is DummyNormalform
+    assert op._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+    # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_down
+@build_transform_script
+def test_normalform_propagate_down(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op")
+    op.normalize(DummyNormalform)
+    assert nested_handle._normalform is DummyNormalform
+    assert op._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down
+@build_transform_script
+def test_normalform_propagate_up_and_down(op: OpHandle):
+    nested_handle = op.match_ops("dummy.op1")
+    nested_nested_handle = nested_handle.match_ops("dummy.op2")
+    nested_handle.normalize(DummyNormalform)
+    assert nested_handle._normalform is DummyNormalform
+    assert op._normalform is DummyNormalform
+    assert nested_nested_handle._normalform is DummyNormalform
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]}
+    # CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]}
+    # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
 # CHECK-LABEL: TEST: test_print_message
 @build_transform_script
 def test_print_message(op: OpHandle):


