[Mlir-commits] [mlir] [mlir][python] meta region_op (PR #75673)

Maksim Levental llvmlistbot at llvm.org
Fri Dec 15 16:28:45 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/75673

>From ef0ec7bd798a267728463ddd69741add16dd0a40 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 15 Dec 2023 18:21:31 -0600
Subject: [PATCH] [mlir][python] meta region_op

---
 mlir/python/CMakeLists.txt                    |   9 +-
 mlir/python/mlir/dialects/func.py             |   3 +
 .../extras/dialects/transform/__init__.py     |   6 +
 mlir/python/mlir/extras/meta.py               |  59 ++++++++++
 mlir/test/python/dialects/transform_extras.py | 105 +++++++++++++++++-
 5 files changed, 178 insertions(+), 4 deletions(-)
 create mode 100644 mlir/python/mlir/extras/meta.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..016031d2a1c121 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
-    extras/types.py
     dialects/_ods_common.py
 
     # The main _mlir module has submodules: include stubs from each.
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/_mlir/passmanager.pyi
 )
 
+declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  ADD_TO_PARENT MLIRPythonSources.Core.Python
+  SOURCES
+    extras/types.py
+    extras/meta.py
+)
+
 declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   ADD_TO_PARENT MLIRPythonSources
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 6599f67b707877..24fdcbcd85b29f 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -243,6 +243,9 @@ def emit_call_op(*call_args):
         return decorator
 
 
+func = FuncOp.from_py_func
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class CallOp(CallOp):
     """Specialization for the call op class."""
diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/extras/dialects/transform/__init__.py
index 9e313324318aa6..8dfe4cb06b34bd 100644
--- a/mlir/python/mlir/extras/dialects/transform/__init__.py
+++ b/mlir/python/mlir/extras/dialects/transform/__init__.py
@@ -3,8 +3,10 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from __future__ import annotations
+
 from typing import Callable, Optional, Sequence
 
+from ...meta import region_op
 from .... import ir
 from ....dialects import transform
 from ....dialects.transform import structured
@@ -146,3 +148,7 @@ def test_match_ops_single(module: OpHandle):
 
     if dump_script:
         print(named_sequence_op)
+
+
+sequence = region_op(transform.SequenceOp.__base__, terminator=transform.YieldOp)
+apply_patterns = region_op(transform.ApplyPatternsOp)
diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
new file mode 100644
index 00000000000000..dce61d80eeea60
--- /dev/null
+++ b/mlir/python/mlir/extras/meta.py
@@ -0,0 +1,59 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import inspect
+from functools import wraps
+
+from ..dialects._ods_common import get_op_result_or_op_results
+from ..ir import Type, InsertionPoint
+
+
+def op_region_builder(op, op_region, terminator=None):  # returns a decorator that fills op_region
+    def builder_wrapper(body_builder):
+        # No block yet: add one whose argument types come from body_builder's parameter annotations.
+        if len(op_region.blocks) == 0:
+            sig = inspect.signature(body_builder)
+            types = [p.annotation for p in sig.parameters.values()]
+            if not (
+                len(types) == len(sig.parameters)
+                and all(isinstance(t, Type) for t in types)  # an unannotated parameter fails this too
+            ):
+                raise ValueError(
+                    f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
+                )
+
+            op_region.blocks.append(*types)
+
+        with InsertionPoint(op_region.blocks[0]):  # body_builder runs with the first block active
+            results = body_builder(*list(op_region.blocks[0].arguments))
+
+        with InsertionPoint(list(op_region.blocks)[-1]):  # terminator is created in the last block
+            if terminator is not None:
+                res = []  # normalize body_builder's return (None / single value / tuple or list) to a list
+                if isinstance(results, (tuple, list)):
+                    res.extend(results)
+                elif results is not None:
+                    res.append(results)
+                terminator(res)
+
+        return get_op_result_or_op_results(op)
+
+    return builder_wrapper
+
+
+def region_op(op_constructor, terminator=None):  # decorator factory: build the op, then fill regions[0]
+    def op_decorator(*args, **kwargs):
+        op = op_constructor(*args, **kwargs)  # decorator arguments are forwarded to the op constructor
+        op_region = op.regions[0]  # only the op's first region is handed to the builder
+
+        return op_region_builder(op, op_region, terminator)
+
+    @wraps(op_decorator)
+    def maybe_no_args(*args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+            return op_decorator()(args[0])  # bare use: the sole callable is the body; op built with no args
+        else:
+            return op_decorator(*args, **kwargs)  # parenthesized use: returns the body decorator
+
+    return maybe_no_args
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index dbfa8a2dc73c41..b4364c450b4d0a 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -2,9 +2,33 @@
 
 from typing import Callable
 from mlir import ir
-from mlir.dialects import scf
-from mlir.dialects.transform import structured
-from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+from mlir.dialects import scf, pdl, func, arith, linalg
+from mlir.dialects.transform import (
+    structured,
+    get_parent_op,
+    apply_patterns_canonicalization,
+    apply_cse,
+)
+from mlir.dialects.transform import AnyOpType, FailurePropagationMode
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.extras.dialects.transform import (
+    OpHandle,
+    insert_transform_script,
+    sequence,
+    apply_patterns,
+)
+from mlir.extras import types as T
+
+
+def construct_and_print_in_module(f):  # decorator: calls f once, at decoration time
+    print("\nTEST:", f.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        with ir.InsertionPoint(module.body):
+            f()  # f runs with the new module's body as the active insertion point
+        print(module)
+    return f
 
 
 def build_transform_script(script: Callable[[OpHandle], None]):
@@ -93,3 +117,78 @@ def test_match_ops_mixed(op: OpHandle):
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
     # CHECK-SAME:     -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_sequence_region
+@construct_and_print_in_module
+def test_sequence_region():
+    any_op_t = AnyOpType.get()
+    pdl_t = pdl.OperationType.get()
+
+    # CHECK-LABEL:   func.func @loop_unroll_op() {
+    # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+    # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
+    # CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
+    # CHECK:           scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.func()
+    def loop_unroll_op():
+        for i in scf.for_(0, 42, 5):
+            v = arith.addi(i, i)
+            scf.yield_([])
+
+    # CHECK:   transform.sequence  failures(propagate) {
+    # CHECK:   ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+    # CHECK:     transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+    # CHECK:   }
+    @sequence([], FailurePropagationMode.Propagate, [])
+    def basic(target: any_op_t):  # the annotation supplies the ^bb0 argument type expected above
+        m = structured_match(any_op_t, target, ops=["arith.addi"])
+        loop = get_parent_op(pdl_t, m, op_name="scf.for")
+        loop_unroll(loop, 4)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+@construct_and_print_in_module
+def test_apply_patterns():
+    any_op_t = AnyOpType.get()
+    pdl_t = pdl.OperationType.get()
+    M, N, K = 3, 5, 3  # tensor dims used below: A is MxN, B is NxK, C is MxK
+
+    # CHECK-LABEL:   func.func @matmul(
+    # CHECK-SAME:                      %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+    # CHECK:           %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+    # CHECK:           return %[[VAL_3]] : tensor<3x3xf32>
+    # CHECK:         }
+    @func.func(
+        T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
+    )
+    def matmul(A, B, C):
+        return linalg.matmul(A, B, outs=[C])
+
+    # CHECK:   transform.sequence  failures(propagate) {
+    # CHECK:   ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+    # CHECK:     apply_patterns to %[[VAL_2]] {
+    # CHECK:       transform.apply_patterns.canonicalization
+    # CHECK:     } : !pdl.operation
+    # CHECK:     %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:     apply_cse to %[[VAL_3]] : !transform.any_op
+    # CHECK:   }
+    @sequence([], FailurePropagationMode.Propagate, [])
+    def basic(variant_op: any_op_t):
+        matmul = structured_match(any_op_t, variant_op, ops=["linalg.matmul"])  # local; shadows the outer matmul
+        top_func = get_parent_op(pdl_t, matmul, op_name="func.func")
+
+        @apply_patterns(top_func)
+        def pats():  # takes no parameters
+            apply_patterns_canonicalization()
+
+        top_func = structured_match(any_op_t, variant_op, ops=["func.func"])  # rebinds top_func to a new match
+        apply_cse(top_func)



More information about the Mlir-commits mailing list