[Mlir-commits] [mlir] ccd7f0f - [mlir][memref][transform][python] Create mix-in for MemRefMultiBufferOp.

Ingo Müller llvmlistbot at llvm.org
Tue Aug 1 00:56:45 PDT 2023


Author: Ingo Müller
Date: 2023-08-01T07:56:40Z
New Revision: ccd7f0f1c38966b3e5e0a231e27152f67c6b7dc8

URL: https://github.com/llvm/llvm-project/commit/ccd7f0f1c38966b3e5e0a231e27152f67c6b7dc8
DIFF: https://github.com/llvm/llvm-project/commit/ccd7f0f1c38966b3e5e0a231e27152f67c6b7dc8.diff

LOG: [mlir][memref][transform][python] Create mix-in for MemRefMultiBufferOp.

Create a mix-in class with an overloaded constructor that makes the
result type optional (it defaults to `!transform.any_op`).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D156561

Added: 
    mlir/python/mlir/dialects/_memref_transform_ops_ext.py
    mlir/test/python/dialects/transform_memref_ext.py

Modified: 
    mlir/python/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d9c1a98bca88bc..a2aa493e2d827b 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -178,6 +178,7 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/MemRefTransformOps.td
   SOURCES
+    dialects/_memref_transform_ops_ext.py
     dialects/transform/memref.py
   DIALECT_NAME transform
   EXTENSION_NAME memref_transform)

diff  --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
new file mode 100644
index 00000000000000..4afe8e7b887f68
--- /dev/null
+++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
@@ -0,0 +1,104 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+    from ..ir import *
+    from ..dialects import transform
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Union
+
+
+class MemRefMultiBufferOp:
+    """Specialization for MemRefMultiBufferOp class.
+
+    Overloads the constructor so that the result type can be omitted. Two
+    positional forms are accepted:
+
+      MemRefMultiBufferOp(transformed_type, target, factor, ...)
+      MemRefMultiBufferOp(target, factor, ...)
+
+    In the second form the result is typed as `transform.AnyOpType`. The
+    forms are told apart by position; the overloads' parameter names
+    (`transformed_type`, `target`, `factor`) are not accepted as keywords.
+    """
+
+    @overload
+    def __init__(
+        self,
+        transformed_type: Type,
+        target: Union[Operation, OpView, Value],
+        factor: Union[int, IntegerAttr],
+        *,
+        skip_analysis: Optional[bool] = None,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        factor: Union[int, IntegerAttr],
+        *,
+        skip_analysis: Optional[bool] = None,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    def __init__(
+        self,
+        transformed_type_or_target: Union[Type, Operation, OpView, Value],
+        target_or_factor: Optional[
+            Union[int, IntegerAttr, Operation, OpView, Value]
+        ] = None,
+        factor_or_none: Optional[Union[int, IntegerAttr]] = None,
+        *,
+        skip_analysis: Optional[bool] = None,
+        loc=None,
+        ip=None
+    ):
+        """Forward either positional form to the base-class constructor.
+
+        Raises:
+          TypeError: if the target or factor is missing, or if three positional
+            arguments are given but the first one is not a `Type`.
+        """
+        # A `Type` in first position selects the explicit-result-type form;
+        # anything else is taken to be the target of the compact form.
+        if isinstance(transformed_type_or_target, Type):
+            transformed_type = transformed_type_or_target
+            target = target_or_factor
+            factor = factor_or_none
+        else:
+            if factor_or_none is not None:
+                # The compact form takes only two positional arguments; reject
+                # a stray third one instead of silently dropping it.
+                raise TypeError(
+                    "MemRefMultiBufferOp: got 3 positional arguments but the "
+                    "first one is not a Type"
+                )
+            transformed_type = transform.AnyOpType.get()
+            target = transformed_type_or_target
+            factor = target_or_factor
+
+        # Both overloads require a target and a factor; report a missing one
+        # here instead of passing None on to the base-class constructor.
+        if target is None or factor is None:
+            raise TypeError(
+                "MemRefMultiBufferOp: expected (transformed_type, target, factor) "
+                "or (target, factor)"
+            )
+
+        super().__init__(
+            transformed_type,
+            target,
+            factor,
+            skip_analysis=skip_analysis,
+            loc=loc,
+            ip=ip,
+        )

diff  --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
new file mode 100644
index 00000000000000..f130fbd829a997
--- /dev/null
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -0,0 +1,75 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import memref
+
+
+def run(f):
+    """Run test `f` in a fresh context and module, then print the module.
+
+    Used as a decorator: each test builds ops at the module's insertion point
+    and its `# CHECK` comments are matched against the printed IR.
+    """
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
+
+
+@run
+def testMemRefMultiBufferOpCompact():
+    """Result type omitted: it defaults to !transform.any_op."""
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefMultiBufferOp(sequence.bodyTarget, 4)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefMultiBufferOpCompact
+    # CHECK: = transform.memref.multibuffer
+    # CHECK-SAME: factor = 4 : i64
+    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
+
+
+@run
+def testMemRefMultiBufferOpTyped():
+    """An explicitly given result type is used as-is."""
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefMultiBufferOp(
+            transform.OperationType.get("memref.alloc"), sequence.bodyTarget, 4
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefMultiBufferOpTyped
+    # CHECK: = transform.memref.multibuffer
+    # CHECK-SAME: factor = 4 : i64
+    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.op<"memref.alloc">
+
+
+@run
+def testMemRefMultiBufferOpAttributes():
+    """The `skip_analysis` keyword is emitted as an attribute."""
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefMultiBufferOp(sequence.bodyTarget, 4, skip_analysis=True)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefMultiBufferOpAttributes
+    # CHECK: = transform.memref.multibuffer
+    # CHECK-SAME: factor = 4 : i64
+    # CHECK-SAME: skip_analysis
+    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op

diff  --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 2db3444a9acff7..3942aa152032ba 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -947,6 +947,7 @@ filegroup(
         "mlir/dialects/_bufferization_transform_ops_ext.py",
         "mlir/dialects/_gpu_transform_ops_ext.py",
         "mlir/dialects/_loop_transform_ops_ext.py",
+        "mlir/dialects/_memref_transform_ops_ext.py",
         "mlir/dialects/_structured_transform_ops_ext.py",
         "mlir/dialects/_transform_ops_ext.py",
         "mlir/dialects/_transform_pdl_extension_ops_ext.py",


        


More information about the Mlir-commits mailing list