[Mlir-commits] [mlir] 7fd6f40 - [mlir][python] Add custom constructor for memref load

Alex Zinenko llvmlistbot at llvm.org
Wed Oct 13 08:11:08 PDT 2021


Author: Alex Zinenko
Date: 2021-10-13T17:11:02+02:00
New Revision: 7fd6f40dbd4ebd5c3378819e51a42d9d24c3dd9d

URL: https://github.com/llvm/llvm-project/commit/7fd6f40dbd4ebd5c3378819e51a42d9d24c3dd9d
DIFF: https://github.com/llvm/llvm-project/commit/7fd6f40dbd4ebd5c3378819e51a42d9d24c3dd9d.diff

LOG: [mlir][python] Add custom constructor for memref load

The result type can be inferred trivially (it is the element type of the
memref operand), but ODS currently does so by stitching C++ strings
together, which is not easily exposed to Python.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D111712

Added: 
    mlir/python/mlir/dialects/_memref_ops_ext.py

Modified: 
    mlir/test/python/dialects/memref.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
new file mode 100644
index 0000000000000..cb25ef105d73f
--- /dev/null
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -0,0 +1,42 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+class LoadOp:
+  """Specialization for the MemRef load operation."""
+
+  def __init__(self,
+               memref: Union[Operation, OpView, Value],
+               indices: Optional[Union[Operation, OpView,
+                                       Sequence[Value]]] = None,
+               *,
+               loc=None,
+               ip=None):
+    """Creates a memref load operation.
+
+    The result type is inferred as the element type of `memref`.
+
+    Args:
+      memref: the buffer to load from; its type must be a MemRefType.
+      indices: the list of subscripts, may be empty for zero-dimensional
+        buffers.
+      loc: user-visible location of the operation.
+      ip: insertion point.
+    """
+    memref_resolved = _get_op_result_or_value(memref)
+    indices_resolved = [] if indices is None else _get_op_results_or_values(
+        indices)
+    # A load yields a single element, so the result has the memref's element
+    # type rather than the memref type itself.
+    return_type = MemRefType(memref_resolved.type).element_type
+    super().__init__(return_type, memref_resolved, indices_resolved, loc=loc,
+                     ip=ip)

diff  --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 240fb9c221e9e..e421f9b2fde95 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -8,9 +8,11 @@
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f  # Return f so `run` can also be applied as a decorator.
 
 
 # CHECK-LABEL: TEST: testSubViewAccessors
+@run
 def testSubViewAccessors():
   ctx = Context()
   module = Module.parse(
@@ -52,4 +54,20 @@ def testSubViewAccessors():
   print(subview.strides[1])
 
 
-run(testSubViewAccessors)
+# CHECK-LABEL: TEST: testCustomBuilders
+@run
+def testCustomBuilders():
+  with Context() as ctx, Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
+        return
+      }
+    """)
+    func = module.body.operations[0]
+    func_body = func.regions[0].blocks[0]
+    with InsertionPoint.at_block_terminator(func_body):
+      memref.LoadOp(func.arguments[0], func.arguments[1:])  # No result type.
+
+    # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+    # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+    print(module)


        


More information about the Mlir-commits mailing list