[Mlir-commits] [mlir] [mlir][python] Add normalforms to capture preconditions of transforms (PR #79449)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 26 03:28:11 PST 2024
https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/79449
>From 2d1cbe436bfb9b1991d2360849441c128efd6f01 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Lu=CC=88cke?= <martin.luecke at ed.ac.uk>
Date: Fri, 26 Jan 2024 12:27:55 +0100
Subject: [PATCH] [mlir][python] Add normalforms to capture preconditions for
transforms
---
.../dialects/transform/extras/__init__.py | 83 ++++++++++++++++-
mlir/test/python/dialects/transform_extras.py | 89 +++++++++++++++++++
2 files changed, 171 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a36f..8f7d04ddb1f9e54 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Type, TypeVar, Union
from ....extras.meta import region_op
from .... import ir
@@ -20,6 +20,8 @@
)
from .. import structured
+# Type variable bound to `Handle`; lets methods such as `Handle.normalize`
+# be annotated as returning the same handle type they were called on.
+HandleT = TypeVar("HandleT", bound="Handle")
+
class Handle(ir.Value):
"""
@@ -42,6 +44,46 @@ def __init__(
super().__init__(v)
self.parent = parent
self.children = children if children is not None else []
+ self._normalForm = NormalForm
+
+ @property
+ def normalForm(self) -> Type["NormalForm"]:
+ """
+ The normalform of this handle. This is a static property of the handle
+ and indicates a group of previously applied transforms. This can be used
+ by subsequent transforms to statically reason about the structure of the
+ payload operations and whether other enabling transforms could possibly
+ be skipped.
+ Setting this property triggers propagation of the normalform to parent
+ and child handles depending on the specific normalform.
+ """
+ return self._normalForm
+
+ @normalForm.setter
+ def normalForm(self, normalForm: Type["NormalForm"]):
+ self._normalForm = normalForm
+ if self._normalForm.propagate_up:
+ self.propagate_up_normalform(normalForm)
+ if self._normalForm.propagate_down:
+ self.propagate_down_normalform(normalForm)
+
+ def propagate_up_normalform(self, normalForm: Type["NormalForm"]):
+ if self.parent:
+ # We set the parent normalform directly to avoid infinite recursion
+ # in case this normalform needs to be propagated up and down.
+ self.parent._normalForm = normalForm
+ self.parent.propagate_up_normalform(normalForm)
+
+ def propagate_down_normalform(self, normalForm: Type["NormalForm"]):
+ for child in self.children:
+ # We set the child normalform directly to avoid infinite recursion
+ # in case this normalform needs to be propagated up and down.
+ child._normalForm = normalForm
+ child.propagate_down_normalform(normalForm)
+
+ def normalize(self: HandleT, normalForm: Type["NormalForm"]) -> HandleT:
+ return normalForm.apply(self)
+
@ir.register_value_caster(AnyOpType.get_static_typeid())
@ir.register_value_caster(OperationType.get_static_typeid())
@@ -192,6 +234,45 @@ def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
return op.param
+class NormalForm:
+ """
+ Represents the weakest normalform and is the base class for all normalforms.
+ A normalform is defined as a sequence of transforms to be applied to a
+ handle to reach this normalform.
+
+ `propagate_up`: Propagate this normalform up to parent handles.
+ `propagate_down`: Propagate this normalform down to all child handles
+ """
+
+ propagate_up: bool = True
+ propagate_down: bool = True
+
+ def __init__(self):
+ raise TypeError(
+ "NormalForm cannot be instantiated directly. Use Type[NormalForm]"
+ "instead."
+ )
+
+ @classmethod
+ def _impl(cls, handle: HandleT) -> HandleT:
+ """
+ Defines the transforms required to reach this normalform.
+ A normalform may apply arbitrary transforms and thus possibly
+ invalidate `handle`.
+ """
+ return handle
+
+ @classmethod
+ def apply(cls, handle: HandleT) -> HandleT:
+ """Apply transforms to a handle to bring it into this normalform."""
+ new_handle = cls._impl(handle)
+ new_handle.children.extend(handle.children)
+ new_handle.parent = handle.parent
+ # Setting this property propagates the normalform accordingly
+ new_handle.normalForm = cls
+ return new_handle
+
+
def insert_transform_script(
block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
script: Callable[[OpHandle], None],
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index ea47f170cb63212..b329609f795b52d 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -19,6 +19,7 @@
insert_transform_script,
sequence,
apply_patterns,
+ NormalForm,
)
from mlir.extras import types as T
@@ -55,6 +56,12 @@ def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]
module.operation.verify()
+def run(f: Callable[[], None]):
+ print("\nTEST:", f.__name__)
+ with ir.Context(), ir.Location.unknown():
+ f()
+
+
# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
@@ -175,6 +182,88 @@ def test_match_ops_mixed(op: OpHandle):
# CHECK-SAME: -> !transform.any_op
+# CHECK-LABEL: TEST: test_normalform_base
+ at build_transform_script
+def test_normalform_base(op: OpHandle):
+ # Normalform is the weakest normalform so op should already be in that form.
+ # Normalization to Normalform should be a no-op.
+ assert op._normalForm is NormalForm
+ op.normalize(NormalForm)
+ assert op._normalForm is NormalForm
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: transform.yield
+
+
+class DummyNormalform(NormalForm):
+ propagate_up: bool = True
+ propagate_down: bool = True
+
+ @classmethod
+ def _impl(cls, handle: OpHandle) -> OpHandle:
+ return handle.print("dummy normalization")
+
+
+# CHECK-LABEL: test_normalform_no_instantiation
+ at run
+def test_normalform_no_instantiation():
+ try:
+ DummyNormalform()
+ except TypeError as e:
+ print(e)
+ else:
+ print("Exception not produced")
+
+ # CHECK: NormalForm cannot be instantiated directly
+
+
+# CHECK-LABEL: TEST: test_normalform_dummyform
+ at build_transform_script
+def test_normalform_dummyform(op: OpHandle):
+ op.normalize(DummyNormalform)
+ assert op._normalForm is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up
+ at build_transform_script
+def test_normalform_propagate_up(op: OpHandle):
+ nested_handle = op.match_ops("dummy.op")
+ nested_handle.normalize(DummyNormalform)
+ assert nested_handle._normalForm is DummyNormalform
+ assert op._normalForm is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+ # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_down
+ at build_transform_script
+def test_normalform_propagate_down(op: OpHandle):
+ nested_handle = op.match_ops("dummy.op")
+ op.normalize(DummyNormalform)
+ assert nested_handle._normalForm is DummyNormalform
+ assert op._normalForm is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op"]}
+ # CHECK-NEXT: transform.print %[[VAL_0]] {name = "dummy normalization"}
+
+
+# CHECK-LABEL: TEST: test_normalform_propagate_up_and_down
+ at build_transform_script
+def test_normalform_propagate_up_and_down(op: OpHandle):
+ nested_handle = op.match_ops("dummy.op1")
+ nested_nested_handle = nested_handle.match_ops("dummy.op2")
+ nested_handle.normalize(DummyNormalform)
+ assert nested_handle._normalForm is DummyNormalform
+ assert op._normalForm is DummyNormalform
+ assert nested_nested_handle._normalForm is DummyNormalform
+ # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["dummy.op1"]}
+ # CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match ops{["dummy.op2"]}
+ # CHECK-NEXT: transform.print %[[VAL_1]] {name = "dummy normalization"}
+
+
# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
More information about the Mlir-commits
mailing list