[Mlir-commits] [mlir] [mlir][python] Add normalforms to capture preconditions of transforms (PR #79449)

Maksim Levental llvmlistbot at llvm.org
Thu Jan 25 11:48:01 PST 2024


================
@@ -42,6 +44,46 @@ def __init__(
         super().__init__(v)
         self.parent = parent
         self.children = children if children is not None else []
+        self._normalform = Normalform
+
+    @property
+    def normalform(self) -> Type["Normalform"]:
+        """
+        The normalform of this handle. This is a static property of the handle
+        and indicates a group of previously applied transforms. This can be used
+        by subsequent transforms to statically reason about the structure of the
+        payload operations and whether other enabling transforms could possibly
+        be skipped.
+        Setting this property triggers propagation of the normalform to parent
+        and child handles depending on the specific normalform.
+        """
+        return self._normalform
+
+    @normalform.setter
+    def normalform(self, normalform: Type["Normalform"]):
+        self._normalform = normalform
+        if self._normalform.propagate_up:
+            self.propagate_up_normalform(normalform)
+        if self._normalform.propagate_down:
+            self.propagate_down_normalform(normalform)
+
+    def propagate_up_normalform(self, normalform: Type["Normalform"]):
+        if self.parent:
+            # We set the parent normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            self.parent._normalform = normalform
+            self.parent.propagate_up_normalform(normalform)
+
+    def propagate_down_normalform(self, normalform: Type["Normalform"]):
+        for child in self.children:
+            # We set the child normalform directly to avoid infinite recursion
+            # in case this normalform needs to be propagated up and down.
+            child._normalform = normalform
+            child.propagate_down_normalform(normalform)
+
+    def normalize(self: "HandleT", normalform: Type["Normalform"]) -> "HandleT":
----------------
makslevental wrote:

`HandleT` is defined above, so you can just use it here without the quotes?
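Below is a minimal, self-contained sketch of this suggestion, with the propagation logic from the hunk included so it can be run. It assumes `HandleT` is a module-level `TypeVar` bound to `Handle`. Its real definition is not visible in this hunk. `Normalform`, `UpOnly`, and the one-line body of `normalize` are hypothetical simplifications for illustration. The PR's actual `normalize` presumably also applies transforms.

```python
from typing import List, Optional, Type, TypeVar


class Normalform:
    # Hypothetical base: only the two flags the setter in the diff reads.
    propagate_up: bool = True
    propagate_down: bool = True


class UpOnly(Normalform):
    # Hypothetical normalform that propagates to parents only.
    propagate_down = False


# Assumed definition: declared before the class, so later annotations can
# refer to it unquoted.
HandleT = TypeVar("HandleT", bound="Handle")


class Handle:
    def __init__(self, parent: Optional["Handle"] = None,
                 children: Optional[List["Handle"]] = None):
        self.parent = parent
        self.children = children if children is not None else []
        self._normalform: Type[Normalform] = Normalform

    @property
    def normalform(self) -> Type[Normalform]:
        return self._normalform

    @normalform.setter
    def normalform(self, normalform: Type[Normalform]) -> None:
        self._normalform = normalform
        if normalform.propagate_up:
            self.propagate_up_normalform(normalform)
        if normalform.propagate_down:
            self.propagate_down_normalform(normalform)

    def propagate_up_normalform(self, normalform: Type[Normalform]) -> None:
        if self.parent:
            # Write the private field so the parent's setter is not re-run.
            self.parent._normalform = normalform
            self.parent.propagate_up_normalform(normalform)

    def propagate_down_normalform(self, normalform: Type[Normalform]) -> None:
        for child in self.children:
            # Same reasoning as above, walking toward the leaves.
            child._normalform = normalform
            child.propagate_down_normalform(normalform)

    # HandleT is already bound at this point, so no string annotation needed.
    def normalize(self: HandleT, normalform: Type[Normalform]) -> HandleT:
        # Hypothetical body: only records the normalform on this handle.
        self.normalform = normalform
        return self


# Usage: a three-level handle tree (links built by hand for brevity).
root = Handle()
child = Handle(parent=root)
root.children.append(child)
grandchild = Handle(parent=child)
child.children.append(grandchild)

result = child.normalize(UpOnly)
```

After the call, `root` carries `UpOnly` while `grandchild` keeps the default `Normalform`, because `UpOnly` disables downward propagation. The helpers assign `_normalform` directly rather than going through the property. This follows the comments in the diff: going through the setter again would bounce between parent and child for a normalform that propagates in both directions.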

https://github.com/llvm/llvm-project/pull/79449

