[Mlir-commits] [mlir] [mlir][python] meta region_op (PR #75673)

Maksim Levental llvmlistbot at llvm.org
Wed Dec 20 15:34:04 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/75673

>From d76bd70fccb510a153cff4084857c4b678624be8 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 15 Dec 2023 18:21:31 -0600
Subject: [PATCH 1/7] [mlir][python] meta region_op

---
 mlir/python/CMakeLists.txt                    |   9 +-
 mlir/python/mlir/dialects/func.py             |   3 +
 mlir/python/mlir/dialects/pdl.py              |   4 +
 .../mlir/dialects/transform/__init__.py       |   8 +-
 .../dialects/transform/extras/__init__.py     |  14 ++-
 mlir/python/mlir/extras/meta.py               |  59 ++++++++++
 mlir/test/python/dialects/transform_extras.py | 101 +++++++++++++++++-
 7 files changed, 191 insertions(+), 7 deletions(-)
 create mode 100644 mlir/python/mlir/extras/meta.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55c5973e40e525..3c9cf304d88a27 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
-    extras/types.py
     dialects/_ods_common.py
 
     # The main _mlir module has submodules: include stubs from each.
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/_mlir/passmanager.pyi
 )
 
+declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  ADD_TO_PARENT MLIRPythonSources.Core.Python
+  SOURCES
+    extras/types.py
+    extras/meta.py
+)
+
 declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   ADD_TO_PARENT MLIRPythonSources
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 6599f67b707877..24fdcbcd85b29f 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -243,6 +243,9 @@ def emit_call_op(*call_args):
         return decorator
 
 
+func = FuncOp.from_py_func
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class CallOp(CallOp):
     """Specialization for the call op class."""
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index 90d7d706238e64..de239f23d9fa96 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -220,3 +220,7 @@ def __init__(
             constantTypes = []
         result = pdl.RangeType.get(pdl.TypeType.get())
         super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
+
+
+def op_t():
+    return OperationType.get()
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 175634c7d458f1..435c1668d0d70a 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -175,7 +175,7 @@ def __init__(
         result_types: Sequence[Type],
         sym_visibility=None,
         arg_attrs=None,
-        res_attrs=None
+        res_attrs=None,
     ):
         function_type = FunctionType.get(input_types, result_types)
         super().__init__(
@@ -183,7 +183,7 @@ def __init__(
             function_type=TypeAttr.get(function_type),
             sym_visibility=sym_visibility,
             arg_attrs=arg_attrs,
-            res_attrs=res_attrs
+            res_attrs=res_attrs,
         )
         self.regions[0].blocks.append(*input_types)
 
@@ -212,3 +212,7 @@ def __init__(
         if operands is None:
             operands = []
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+
+
+def any_op_t():
+    return AnyOpType.get()
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index c715dac1ef7eb8..1e8c0652389226 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -4,8 +4,16 @@
 
 from typing import Callable, Optional, Sequence, Union
 
+from ....extras.meta import region_op
 from .... import ir
-from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import (
+    AnyOpType,
+    OperationType,
+    NamedSequenceOp,
+    YieldOp,
+    SequenceOp,
+    ApplyPatternsOp,
+)
 from .. import structured
 
 
@@ -147,3 +155,7 @@ def test_match_ops_single(module: OpHandle):
 
     if dump_script:
         print(named_sequence_op)
+
+
+sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
+apply_patterns = region_op(ApplyPatternsOp)
diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
new file mode 100644
index 00000000000000..dce61d80eeea60
--- /dev/null
+++ b/mlir/python/mlir/extras/meta.py
@@ -0,0 +1,59 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import inspect
+from functools import wraps
+
+from ..dialects._ods_common import get_op_result_or_op_results
+from ..ir import Type, InsertionPoint
+
+
+def op_region_builder(op, op_region, terminator=None):
+    def builder_wrapper(body_builder):
+        # add a block with block args having types ...
+        if len(op_region.blocks) == 0:
+            sig = inspect.signature(body_builder)
+            types = [p.annotation for p in sig.parameters.values()]
+            if not (
+                len(types) == len(sig.parameters)
+                and all(isinstance(t, Type) for t in types)
+            ):
+                raise ValueError(
+                    f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
+                )
+
+            op_region.blocks.append(*types)
+
+        with InsertionPoint(op_region.blocks[0]):
+            results = body_builder(*list(op_region.blocks[0].arguments))
+
+        with InsertionPoint(list(op_region.blocks)[-1]):
+            if terminator is not None:
+                res = []
+                if isinstance(results, (tuple, list)):
+                    res.extend(results)
+                elif results is not None:
+                    res.append(results)
+                terminator(res)
+
+        return get_op_result_or_op_results(op)
+
+    return builder_wrapper
+
+
+def region_op(op_constructor, terminator=None):
+    def op_decorator(*args, **kwargs):
+        op = op_constructor(*args, **kwargs)
+        op_region = op.regions[0]
+
+        return op_region_builder(op, op_region, terminator)
+
+    @wraps(op_decorator)
+    def maybe_no_args(*args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+            return op_decorator()(args[0])
+        else:
+            return op_decorator(*args, **kwargs)
+
+    return maybe_no_args
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index e7b43ea63c31ca..2345e5b375c6df 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -2,9 +2,34 @@
 
 from typing import Callable
 from mlir import ir
-from mlir.dialects import scf
-from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+from mlir.dialects import scf, pdl, func, arith, linalg
+from mlir.dialects.transform import (
+    structured,
+    get_parent_op,
+    apply_patterns_canonicalization,
+    apply_cse,
+    any_op_t,
+)
+from mlir.dialects.transform import FailurePropagationMode
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.dialects.transform.extras import (
+    OpHandle,
+    insert_transform_script,
+    sequence,
+    apply_patterns,
+)
+from mlir.extras import types as T
+
+
+def construct_and_print_in_module(f):
+    print("\nTEST:", f.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        with ir.InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 def build_transform_script(script: Callable[[OpHandle], None]):
@@ -93,3 +118,73 @@ def test_match_ops_mixed(op: OpHandle):
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
     # CHECK-SAME:     -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_sequence_region
+ at construct_and_print_in_module
+def test_sequence_region():
+    # CHECK-LABEL:   func.func @loop_unroll_op() {
+    # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+    # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
+    # CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
+    # CHECK:           scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.func()
+    def loop_unroll_op():
+        for i in scf.for_(0, 42, 5):
+            v = arith.addi(i, i)
+            scf.yield_([])
+
+    # CHECK:   transform.sequence  failures(propagate) {
+    # CHECK:   ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+    # CHECK:     transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+    # CHECK:   }
+    @sequence([], FailurePropagationMode.Propagate, [])
+    def basic(target: any_op_t()):
+        m = structured_match(any_op_t(), target, ops=["arith.addi"])
+        loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
+        loop_unroll(loop, 4)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+ at construct_and_print_in_module
+def test_apply_patterns():
+    M, N, K = 3, 5, 3
+
+    # CHECK-LABEL:   func.func @matmul(
+    # CHECK-SAME:                      %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+    # CHECK:           %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+    # CHECK:           return %[[VAL_3]] : tensor<3x3xf32>
+    # CHECK:         }
+    @func.func(
+        T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
+    )
+    def matmul(A, B, C):
+        return linalg.matmul(A, B, outs=[C])
+
+    # CHECK:   transform.sequence  failures(propagate) {
+    # CHECK:   ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+    # CHECK:     apply_patterns to %[[VAL_2]] {
+    # CHECK:       transform.apply_patterns.canonicalization
+    # CHECK:     } : !pdl.operation
+    # CHECK:     %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:     apply_cse to %[[VAL_3]] : !transform.any_op
+    # CHECK:   }
+    @sequence([], FailurePropagationMode.Propagate, [])
+    def basic(variant_op: any_op_t()):
+        matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+        top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
+
+        @apply_patterns(top_func)
+        def pats():
+            apply_patterns_canonicalization()
+
+        top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
+        apply_cse(top_func)

>From 83201ba6c14929ca10edf7956017fd3a3fa61c89 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 16:35:10 -0600
Subject: [PATCH 2/7] split tests into dialect and integration

---
 mlir/python/mlir/dialects/arith.py            |   8 +
 mlir/python/mlir/dialects/builtin.py          |  23 +++
 mlir/python/mlir/dialects/scf.py              |   2 +-
 .../dialects/transform/extras/__init__.py     |   1 +
 mlir/test/python/dialects/arith_dialect.py    |   6 +-
 mlir/test/python/dialects/transform_extras.py |  34 +---
 .../python/integration/dialects/transform.py  | 150 ++++++++++++++++++
 7 files changed, 189 insertions(+), 35 deletions(-)
 create mode 100644 mlir/test/python/integration/dialects/transform.py

diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 83aca0d58bf2ce..663a53660a6474 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -11,6 +11,8 @@
     from ._ods_common import (
         get_default_loc_context as _get_default_loc_context,
         _cext as _ods_cext,
+        get_op_result_or_op_results as _get_op_result_or_op_results,
+        SubClassValueT as _SubClassValueT,
     )
 
     from typing import Any, List, Union
@@ -75,3 +77,9 @@ def literal_value(self) -> Union[int, float]:
             return FloatAttr(self.value).value
         else:
             raise ValueError("only integer and float constants have literal values")
+
+
+def constant(
+    result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+) -> _SubClassValueT:
+    return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
index b71cc2466d464b..1c69d6d7c3a0bd 100644
--- a/mlir/python/mlir/dialects/builtin.py
+++ b/mlir/python/mlir/dialects/builtin.py
@@ -2,8 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import Dict, Optional
+
 from ._builtin_ops_gen import *
 from ._builtin_ops_gen import _Dialect
+from ..extras.meta import region_op
 
 try:
     from ..ir import *
@@ -23,3 +26,23 @@ def __init__(self, *, loc=None, ip=None):
     @property
     def body(self):
         return self.regions[0].blocks[0]
+
+
+ at region_op
+def module(
+    *,
+    sym_name=None,
+    sym_visibility=None,
+    attrs: Optional[Dict[str, Attribute]] = None,
+    loc=None,
+    ip=None,
+):
+    mod = ModuleOp.__base__(
+        sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip
+    )
+    if attrs is None:
+        attrs = {}
+    for attr_name, attr in attrs.items():
+        mod.operation.attributes[attr_name] = attr
+
+    return mod
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 20bbed9bc93df6..dad7377987e56c 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,7 +120,7 @@ def for_(
     params = [start, stop, step]
     for i, p in enumerate(params):
         if isinstance(p, int):
-            p = constant(IntegerAttr.get(IndexType.get(), p))
+            p = constant(IndexType.get(), p)
         elif isinstance(p, float):
             raise ValueError(f"{p=} must be int.")
         params[i] = p
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 1e8c0652389226..e4d47e9064f2c8 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -158,4 +158,5 @@ def test_match_ops_single(module: OpHandle):
 
 
 sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
+named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
 apply_patterns = region_op(ApplyPatternsOp)
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index f80f2c084a0f3b..8bb80eed2b8105 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -75,7 +75,7 @@ def __str__(self):
         f64_t = F64Type.get()
 
         with InsertionPoint(module.body):
-            a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+            a = arith.constant(f16_t, 42.42)
             # CHECK: ArithValue(%cst = arith.constant 4.240
             print(a)
 
@@ -83,12 +83,12 @@ def __str__(self):
             # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
             print(b)
 
-            a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+            a = arith.constant(f32_t, 42.42)
             b = a - a
             # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
             print(b)
 
-            a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+            a = arith.constant(f64_t, 42.42)
             b = a * a
             # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
             print(b)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index 2345e5b375c6df..358f8c32f75c75 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -2,7 +2,7 @@
 
 from typing import Callable
 from mlir import ir
-from mlir.dialects import scf, pdl, func, arith, linalg
+from mlir.dialects import scf, pdl
 from mlir.dialects.transform import (
     structured,
     get_parent_op,
@@ -123,23 +123,8 @@ def test_match_ops_mixed(op: OpHandle):
 # CHECK-LABEL: TEST: test_sequence_region
 @construct_and_print_in_module
 def test_sequence_region():
-    # CHECK-LABEL:   func.func @loop_unroll_op() {
-    # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
-    # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
-    # CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
-    # CHECK:           scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
-    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
-    # CHECK:           }
-    # CHECK:           return
-    # CHECK:         }
-    @func.func()
-    def loop_unroll_op():
-        for i in scf.for_(0, 42, 5):
-            v = arith.addi(i, i)
-            scf.yield_([])
-
     # CHECK:   transform.sequence  failures(propagate) {
-    # CHECK:   ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
     # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
     # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
     # CHECK:     transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
@@ -154,21 +139,8 @@ def basic(target: any_op_t()):
 # CHECK-LABEL: TEST: test_apply_patterns
 @construct_and_print_in_module
 def test_apply_patterns():
-    M, N, K = 3, 5, 3
-
-    # CHECK-LABEL:   func.func @matmul(
-    # CHECK-SAME:                      %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
-    # CHECK:           %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
-    # CHECK:           return %[[VAL_3]] : tensor<3x3xf32>
-    # CHECK:         }
-    @func.func(
-        T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
-    )
-    def matmul(A, B, C):
-        return linalg.matmul(A, B, outs=[C])
-
     # CHECK:   transform.sequence  failures(propagate) {
-    # CHECK:   ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
     # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
     # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
     # CHECK:     apply_patterns to %[[VAL_2]] {
diff --git a/mlir/test/python/integration/dialects/transform.py b/mlir/test/python/integration/dialects/transform.py
new file mode 100644
index 00000000000000..fd8736235341fa
--- /dev/null
+++ b/mlir/test/python/integration/dialects/transform.py
@@ -0,0 +1,150 @@
+from mlir.passmanager import PassManager
+from mlir.ir import Context, Location, Module, InsertionPoint, UnitAttr
+from mlir.dialects import scf, pdl, func, arith, linalg
+from mlir.dialects.transform import (
+    get_parent_op,
+    apply_patterns_canonicalization,
+    apply_cse,
+    any_op_t,
+)
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.dialects.transform.extras import named_sequence, sequence, apply_patterns
+from mlir.extras import types as T
+from mlir.dialects.builtin import module
+
+
+def construct_and_print_in_module(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            module = f(module)
+        if module is not None:
+            print(module)
+    return f
+
+
+# CHECK-LABEL: TEST: test_sequence_region
+ at construct_and_print_in_module
+def test_sequence_region(module_):
+    # CHECK-LABEL:   func.func @loop_unroll_op() {
+    # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+    # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
+    # CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
+    # CHECK:           scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.func()
+    def loop_unroll_op():
+        for i in scf.for_(0, 42, 5):
+            v = arith.addi(i, i)
+            scf.yield_([])
+
+    # CHECK-LABEL:   module attributes {transform.with_named_sequence} {
+    # CHECK:           transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK:             %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:             %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+    # CHECK:             transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+    # CHECK:             transform.yield
+    # CHECK:           }
+    # CHECK:         }
+    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
+    def mod():
+        @named_sequence("__transform_main", [any_op_t()], [])
+        def basic(target: any_op_t()):
+            m = structured_match(any_op_t(), target, ops=["arith.addi"])
+            loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
+            loop_unroll(loop, 4)
+
+    print(module_)
+
+    pm = PassManager.parse("builtin.module(transform-interpreter)")
+    pm.run(module_.operation)
+
+    # CHECK-LABEL: func.func @loop_unroll_op() {
+    # CHECK:         %[[VAL_0]] = arith.constant 0 : index
+    # CHECK:         %[[VAL_1]] = arith.constant 42 : index
+    # CHECK:         %[[VAL_2]] = arith.constant 5 : index
+    # CHECK:         %[[VAL_6:.*]] = arith.constant 40 : index
+    # CHECK:         %[[VAL_7:.*]] = arith.constant 20 : index
+    # CHECK:         scf.for %[[VAL_3]] = %[[VAL_0]] to %[[VAL_6]] step %[[VAL_7]] {
+    # CHECK:           %[[VAL_5]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+    # CHECK:           %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
+    # CHECK:           %[[VAL_10:.*]] = arith.addi %[[VAL_3]], %[[VAL_9]] : index
+    # CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_10]] : index
+    # CHECK:           %[[VAL_12:.*]] = arith.constant 2 : index
+    # CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_2]], %[[VAL_12]] : index
+    # CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_3]], %[[VAL_13]] : index
+    # CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
+    # CHECK:           %[[VAL_16:.*]] = arith.constant 3 : index
+    # CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_2]], %[[VAL_16]] : index
+    # CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_3]], %[[VAL_17]] : index
+    # CHECK:           %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_18]] : index
+    # CHECK:         }
+    # CHECK:         %[[VAL_4]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:         return
+    # CHECK:       }
+    print(module_)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+ at construct_and_print_in_module
+def test_apply_patterns(module_):
+    M, N, K = 3, 5, 3
+
+    # CHECK-LABEL:   func.func @matmul(
+    # CHECK-SAME:                      %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+    # CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+    # CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
+    # CHECK:           %[[VAL_5:.*]] = linalg.matmul {cast = #[[?]]<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+    # CHECK:           return %[[VAL_5]] : tensor<3x3xf32>
+    # CHECK:         }
+    @func.func(
+        T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
+    )
+    def matmul(A, B, C):
+        i = arith.constant(T.i32(), 1)
+        v = arith.addi(i, i)
+        return linalg.matmul(A, B, outs=[C])
+
+    # CHECK-LABEL:   module attributes {transform.with_named_sequence} {
+    # CHECK:           transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK:             %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:             %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+    # CHECK:             transform.apply_patterns to %[[VAL_2]] {
+    # CHECK:               transform.apply_patterns.canonicalization
+    # CHECK:             } : !pdl.operation
+    # CHECK:             %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK:             transform.apply_cse to %[[VAL_3]] : !transform.any_op
+    # CHECK:             transform.yield
+    # CHECK:           }
+    # CHECK:         }
+    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
+    def mod():
+        @named_sequence("__transform_main", [any_op_t()], [])
+        def basic(variant_op: any_op_t()):
+            matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+            top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
+
+            @apply_patterns(top_func)
+            def pats():
+                apply_patterns_canonicalization()
+
+            top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
+            apply_cse(top_func)
+
+    print(module_)
+
+    pm = PassManager.parse("builtin.module(transform-interpreter)")
+    pm.run(module_.operation)
+
+    # CHECK-LABEL:   func.func @matmul(
+    # CHECK-SAME:                      %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+    # CHECK:           %[[VAL_3:.*]] = linalg.matmul {cast = #[[?]]<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+    # CHECK:           return %[[VAL_3]] : tensor<3x3xf32>
+    # CHECK:         }
+    print(module_)

>From 5003812ef69d1738204a6b7acbe3e979c70f8012 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 16:43:03 -0600
Subject: [PATCH 3/7] add NewType for any_op_t

---
 mlir/python/mlir/dialects/transform/__init__.py    |  9 ++++++---
 mlir/test/python/integration/dialects/transform.py | 11 +++++++----
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 435c1668d0d70a..5b158ec6b65fdd 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -18,7 +18,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Union, NewType
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -214,5 +214,8 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
-def any_op_t():
-    return AnyOpType.get()
+AnyOpTypeT = NewType("AnyOpType", AnyOpType)
+
+
+def any_op_t() -> AnyOpTypeT:
+    return AnyOpTypeT(AnyOpType.get())
diff --git a/mlir/test/python/integration/dialects/transform.py b/mlir/test/python/integration/dialects/transform.py
index fd8736235341fa..5ff5f7829f4893 100644
--- a/mlir/test/python/integration/dialects/transform.py
+++ b/mlir/test/python/integration/dialects/transform.py
@@ -9,9 +9,9 @@
 )
 from mlir.dialects.transform.structured import structured_match
 from mlir.dialects.transform.loop import loop_unroll
-from mlir.dialects.transform.extras import named_sequence, sequence, apply_patterns
+from mlir.dialects.transform.extras import named_sequence, apply_patterns
 from mlir.extras import types as T
-from mlir.dialects.builtin import module
+from mlir.dialects.builtin import module, ModuleOp
 
 
 def construct_and_print_in_module(f):
@@ -25,9 +25,9 @@ def construct_and_print_in_module(f):
     return f
 
 
-# CHECK-LABEL: TEST: test_sequence_region
+# CHECK-LABEL: TEST: test_named_sequence
 @construct_and_print_in_module
-def test_sequence_region(module_):
+def test_named_sequence(module_):
     # CHECK-LABEL:   func.func @loop_unroll_op() {
     # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
     # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
@@ -59,6 +59,9 @@ def basic(target: any_op_t()):
             loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
             loop_unroll(loop, 4)
 
+    # The identifier (name) of the decorated function becomes bound to the
+    # resulting Operation (here a builtin ModuleOp) rather than the function.
+    assert isinstance(mod.opview, ModuleOp)
+
     print(module_)
 
     pm = PassManager.parse("builtin.module(transform-interpreter)")

>From 2e50e1d63b68c20b737a89b9a60ef34c7740e58b Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 17:12:06 -0600
Subject: [PATCH 4/7] add example with result

---
 mlir/python/mlir/dialects/tensor.py |  7 ++++++
 mlir/test/python/dialects/tensor.py | 36 +++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 67248748eaf3ad..79dd9476ad0ff9 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -4,6 +4,7 @@
 
 from ._tensor_ops_gen import *
 from ._tensor_ops_gen import _Dialect
+from ..extras.meta import region_op
 
 try:
     from ..ir import *
@@ -40,3 +41,9 @@ def __init__(
                 dynamic_sizes.append(s)
         result_type = RankedTensorType.get(static_sizes, element_type)
         super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
+
+
+generate = region_op(
+    lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
+    terminator=lambda args: YieldOp(args[0]),
+)
diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
index b690c934dc46bd..6ed77e4441a81c 100644
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -4,6 +4,7 @@
 import mlir.dialects.arith as arith
 import mlir.dialects.func as func
 import mlir.dialects.tensor as tensor
+from mlir.extras import types as T
 
 
 def run(f):
@@ -139,3 +140,38 @@ def default_builder():
                 t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
                 # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
                 print(t)
+
+
+# CHECK-LABEL: TEST: testGenerateRegionOp
+@run
+def testGenerateRegionOp():
+    S = ShapedType.get_dynamic_size()
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+
+            # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+            # CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
+            one = arith.constant(T.index(), 1)
+            two = arith.constant(T.index(), 2)
+
+            @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
+            def generate_one(i: T.index(), j: T.index(), k: T.index()):
+                ij = arith.addi(i, j)
+                ijk = arith.addi(ij, k)
+                return ijk
+
+            assert (
+                isinstance(generate_one, Value)
+                and generate_one.owner.name == "tensor.generate"
+            )
+
+        # CHECK:         %[[GENERATED:.*]] = tensor.generate
+        # CHECK-SAME:    %[[VAL_0]],
+        # CHECK-SAME:    %[[VAL_1]] {
+        # CHECK:         ^bb0(%[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+        # CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
+        # CHECK:           %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_3]] : index
+        # CHECK:           tensor.yield %[[VAL_5]] : index
+        # CHECK:         } : tensor<?x3x?xindex>
+        print(module)

>From 7d8f5798b5457a28ad0b7159e1ee34915cb44e1c Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 17:20:14 -0600
Subject: [PATCH 5/7] add comment describing region_op

---
 mlir/python/mlir/extras/meta.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
index dce61d80eeea60..3f2defadf79412 100644
--- a/mlir/python/mlir/extras/meta.py
+++ b/mlir/python/mlir/extras/meta.py
@@ -11,7 +11,7 @@
 
 def op_region_builder(op, op_region, terminator=None):
     def builder_wrapper(body_builder):
-        # add a block with block args having types ...
+        # Add a block with block args having types determined by type hints on the wrapped function.
         if len(op_region.blocks) == 0:
             sig = inspect.signature(body_builder)
             types = [p.annotation for p in sig.parameters.values()]
@@ -43,6 +43,30 @@ def builder_wrapper(body_builder):
 
 
 def region_op(op_constructor, terminator=None):
+    """Decorator to define an MLIR Op specified as a python function.
+
+    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+    active for the current thread (i.e. established in a `with` block).
+
+    Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
+
+    When applied as a decorator to a Python function, an entry block will
+    be constructed for the Op with types as specified **as type hints on the args of the function**.
+    The block arguments will be passed positionally to the Python function.
+
+    If a terminator is specified then the return from the decorated function will be passed
+    to the terminator as the last statement in the entry block. Note, the API for the terminator
+    is a (possibly empty) list; terminators accepting single values should be wrapped in a
+    `lambda args: term(args[0])`
+
+    The identifier (name) of the function will become:
+    1. A single value result if the Op returns a single value;
+    2. An OpResultList (as a list) if the Op returns multiple values;
+    3. The Operation if the Op returns no results.
+
+    See examples in tensor.py and transform.extras.
+    """
+
     def op_decorator(*args, **kwargs):
         op = op_constructor(*args, **kwargs)
         op_region = op.regions[0]

>From af16f3df5fa77af5cdb99329787272196c9ee782 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 17:28:09 -0600
Subject: [PATCH 6/7] fix formatting

---
 mlir/test/python/dialects/tensor.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
index 6ed77e4441a81c..ca9066b239111b 100644
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -149,7 +149,6 @@ def testGenerateRegionOp():
     with Context(), Location.unknown():
         module = Module.create()
         with InsertionPoint(module.body):
-
             # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
             # CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
             one = arith.constant(T.index(), 1)

>From e487774e03431d228ac5d39a61591b6cd2b439c0 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 20 Dec 2023 17:33:51 -0600
Subject: [PATCH 7/7] newtype for pdl as well

---
 mlir/python/mlir/dialects/pdl.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index de239f23d9fa96..db07dc50aabd79 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -5,6 +5,7 @@
 from ._pdl_ops_gen import *
 from ._pdl_ops_gen import _Dialect
 from .._mlir_libs._mlirDialectsPDL import *
+from .._mlir_libs._mlirDialectsPDL import OperationType
 
 
 try:
@@ -13,7 +14,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Union, Optional, Sequence, Mapping
+from typing import Union, Optional, Sequence, Mapping, NewType
 from ._ods_common import (
     get_op_result_or_value as _get_value,
     get_op_results_or_values as _get_values,
@@ -222,5 +223,8 @@ def __init__(
         super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
 
 
-def op_t():
-    return OperationType.get()
+OperationTypeT = NewType("OperationType", OperationType)
+
+
+def op_t() -> OperationTypeT:
+    return OperationTypeT(OperationType.get())



More information about the Mlir-commits mailing list