[Mlir-commits] [mlir] a13c715 - [mlir][transform][bufferization][python] Add mix-in classes for two ops.

Ingo Müller llvmlistbot at llvm.org
Wed Jul 26 11:00:17 PDT 2023


Author: Ingo Müller
Date: 2023-07-26T18:00:12Z
New Revision: a13c715aae10fc548356b6b99075a71409fd14be

URL: https://github.com/llvm/llvm-project/commit/a13c715aae10fc548356b6b99075a71409fd14be
DIFF: https://github.com/llvm/llvm-project/commit/a13c715aae10fc548356b6b99075a71409fd14be.diff

LOG: [mlir][transform][bufferization][python] Add mix-in classes for two ops.

This patch adds mix-in classes for the Python bindings of
`EmptyTensorToAllocTensorOp` and `OneShotBufferizeOp`. For both classes,
the mix-ins add overloads to the `__init__` functions that allow
constructing them without providing the return type, which defaults to
the only allowed type and `AnyOpType`, respectively.

Note that the mix-ins do not expose the
`function_boundary_type_conversion` attribute. The attribute has a
custom type from the bufferization dialect that is currently not exposed
in the Python bindings. Handling of that attribute can easily be added
to the mix-in class when the need arises.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D155799

Added: 
    mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
    mlir/test/python/dialects/transform_bufferization_ext.py

Modified: 
    mlir/python/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 50fbca38a08fb4..05a36cafc26d07 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -139,6 +139,7 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/BufferizationTransformOps.td
   SOURCES
+    dialects/_bufferization_transform_ops_ext.py
     dialects/transform/bufferization.py
   DIALECT_NAME transform
   EXTENSION_NAME bufferization_transform)

diff  --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
new file mode 100644
index 00000000000000..77f4d1e1608c83
--- /dev/null
+++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
@@ -0,0 +1,114 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+    from ..ir import *
+    from ..dialects import transform
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Union
+
+
+class EmptyTensorToAllocTensorOp:
+    """Specialization for EmptyTensorToAllocTensorOp class."""
+
+    @overload
+    def __init__(
+        self,
+        transformed_type: Type,
+        target: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
+        ...
+
+    def __init__(
+        self,
+        transformed_type_or_target: Type,
+        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(transformed_type_or_target, Type):
+            transformed_type = transformed_type_or_target
+            target = target_or_none
+        else:
+            transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
+            target = transformed_type_or_target
+
+        super().__init__(
+            transformed_type,
+            target,
+            loc=loc,
+            ip=ip,
+        )
+
+
+class OneShotBufferizeOp:
+    """Specialization for OneShotBufferizeOp class."""
+
+    @overload
+    def __init__(
+        self,
+        transformed_type: Type,
+        target: Union[Operation, OpView, Value],
+        *,
+        allow_return_allocs: Optional[bool] = None,
+        allow_unknown_ops: Optional[bool] = None,
+        bufferize_function_boundaries: Optional[bool] = None,
+        create_deallocs: Optional[bool] = None,
+        test_analysis_only: Optional[bool] = None,
+        print_conflicts: Optional[bool] = None,
+        memcpy_op: Optional[str] = None,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
+        ...
+
+    def __init__(
+        self,
+        transformed_type_or_target: Type,
+        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        allow_return_allocs: Optional[bool] = None,
+        allow_unknown_ops: Optional[bool] = None,
+        bufferize_function_boundaries: Optional[bool] = None,
+        create_deallocs: Optional[bool] = None,
+        test_analysis_only: Optional[bool] = None,
+        print_conflicts: Optional[bool] = None,
+        memcpy_op: Optional[str] = None,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(transformed_type_or_target, Type):
+            transformed_type = transformed_type_or_target
+            target = target_or_none
+        else:
+            transformed_type = transform.AnyOpType.get()
+            target = transformed_type_or_target
+
+        super().__init__(
+            transformed_type,
+            target,
+            allow_return_allocs=allow_return_allocs,
+            allow_unknown_ops=allow_unknown_ops,
+            bufferize_function_boundaries=bufferize_function_boundaries,
+            create_deallocs=create_deallocs,
+            test_analysis_only=test_analysis_only,
+            print_conflicts=print_conflicts,
+            memcpy_op=memcpy_op,
+            loc=loc,
+            ip=ip,
+        )

diff  --git a/mlir/test/python/dialects/transform_bufferization_ext.py b/mlir/test/python/dialects/transform_bufferization_ext.py
new file mode 100644
index 00000000000000..b2e213f78c8cdf
--- /dev/null
+++ b/mlir/test/python/dialects/transform_bufferization_ext.py
@@ -0,0 +1,104 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import bufferization
+
+
+def run(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
+
+
+ at run
+def testEmptyTensorToAllocTensorOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("tensor.empty"),
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.EmptyTensorToAllocTensorOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testEmptyTensorToAllocTensorOpCompact
+    # CHECK: = transform.bufferization.empty_tensor_to_alloc_tensor
+    # CHECK-SAME: (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+
+
+ at run
+def testEmptyTensorToAllocTensorOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("tensor.empty"),
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.EmptyTensorToAllocTensorOp(
+            transform.OperationType.get("bufferization.alloc_tensor"),
+            sequence.bodyTarget,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testEmptyTensorToAllocTensorOpTyped
+    # CHECK: = transform.bufferization.empty_tensor_to_alloc_tensor
+    # CHECK-SAME: (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+
+
+ at run
+def testOneShotBufferizeOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.OneShotBufferizeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testOneShotBufferizeOpCompact
+    # CHECK: = transform.bufferization.one_shot_bufferize
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+ at run
+def testOneShotBufferizeOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.OneShotBufferizeOp(
+            transform.OperationType.get("test.dummy"),
+            sequence.bodyTarget,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testOneShotBufferizeOpTyped
+    # CHECK: = transform.bufferization.one_shot_bufferize
+    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
+
+
+ at run
+def testOneShotBufferizeOpAttributes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.OneShotBufferizeOp(
+            sequence.bodyTarget,
+            allow_return_allocs=True,
+            allow_unknown_ops=True,
+            bufferize_function_boundaries=True,
+            create_deallocs=True,
+            test_analysis_only=True,
+            print_conflicts=True,
+            memcpy_op="memref.copy",
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testOneShotBufferizeOpAttributes
+    # CHECK: = transform.bufferization.one_shot_bufferize
+    # CHECK-SAME: allow_return_allocs = true
+    # CHECK-SAME: allow_unknown_ops = true
+    # CHECK-SAME: bufferize_function_boundaries = true
+    # CHECK-SAME: print_conflicts = true
+    # CHECK-SAME: test_analysis_only = true
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op

diff  --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index f01b540f3650b0..7e706c6e773cf8 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -753,6 +753,25 @@ gentbl_filegroup(
     ],
 )
 
+# Generated Python op bindings for the bufferization extension of the
+# transform dialect.
+gentbl_filegroup(
+    name = "BufferizationTransformOpsPyGen",
+    td_file = "mlir/dialects/BufferizationTransformOps.td",
+    tblgen = "//mlir:mlir-tblgen",
+    deps = ["//mlir:BufferizationTransformOpsTdFiles"],
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+                "-dialect-extension=bufferization_transform",
+            ],
+            "mlir/dialects/_bufferization_transform_ops_gen.py",
+        ),
+    ],
+)
+
 gentbl_filegroup(
     name = "GPUTransformOpsPyGen",
     tbl_outs = [
@@ -776,7 +795,6 @@ gentbl_filegroup(
     ],
 )
 
-
 gentbl_filegroup(
     name = "StructuredTransformOpsPyGen",
     tbl_outs = [
@@ -849,11 +867,13 @@ gentbl_filegroup(
 filegroup(
     name = "TransformOpsPyFiles",
     srcs = [
+        "mlir/dialects/_bufferization_transform_ops_ext.py",
         "mlir/dialects/_gpu_transform_ops_ext.py",
         "mlir/dialects/_loop_transform_ops_ext.py",
         "mlir/dialects/_structured_transform_ops_ext.py",
         "mlir/dialects/_transform_ops_ext.py",
         "mlir/dialects/_transform_pdl_extension_ops_ext.py",
+        ":BufferizationTransformOpsPyGen",
         ":GPUTransformOpsPyGen",
         ":LoopTransformOpsPyGen",
         ":PDLTransformOpsPyGen",


        


More information about the Mlir-commits mailing list