[Mlir-commits] [mlir] 7fd6f40 - [mlir][python] Add custom constructor for memref load
Alex Zinenko
llvmlistbot at llvm.org
Wed Oct 13 08:11:08 PDT 2021
Author: Alex Zinenko
Date: 2021-10-13T17:11:02+02:00
New Revision: 7fd6f40dbd4ebd5c3378819e51a42d9d24c3dd9d
URL: https://github.com/llvm/llvm-project/commit/7fd6f40dbd4ebd5c3378819e51a42d9d24c3dd9d
DIFF: https://github.com/llvm/llvm-project/commit/7fd6f40dbd4ebd5c3378819e51a42d9d24c3dd9d.diff
LOG: [mlir][python] Add custom constructor for memref load
The result type can be inferred trivially (it is the element type of the memref), but
this inference is currently done through string stitching between ODS and C++ and is not easily exposed to Python.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D111712
Added:
mlir/python/mlir/dialects/_memref_ops_ext.py
Modified:
mlir/test/python/dialects/memref.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
new file mode 100644
index 0000000000000..cb25ef105d73f
--- /dev/null
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -0,0 +1,37 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+ from ..ir import *
+ from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
class LoadOp:
  """Specialization for the MemRef load operation."""

  def __init__(self,
               memref: Union[Operation, OpView, Value],
               indices: Optional[Union[Operation, OpView,
                                       Sequence[Value]]] = None,
               *,
               loc=None,
               ip=None):
    """Creates a memref load operation.

    The result type is not taken as an argument: it is inferred as the
    element type of `memref`.

    Args:
      memref: the buffer to load from.
      indices: the list of subscripts, may be empty for zero-dimensional
        buffers.
      loc: user-visible location of the operation.
      ip: insertion point.
    """
    memref_resolved = _get_op_result_or_value(memref)
    indices_resolved = [] if indices is None else _get_op_results_or_values(
        indices)
    # A load produces a single element of the buffer, so the result type is
    # the memref's element type (e.g. f32 for memref<?x?xf32>), not the
    # memref type itself.
    return_type = MemRefType(memref_resolved.type).element_type
    super().__init__(return_type, memref_resolved, indices_resolved,
                     loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 240fb9c221e9e..e421f9b2fde95 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -8,9 +8,11 @@
def run(f):
  """Prints a test banner, invokes `f`, and hands `f` back.

  Returning `f` lets this helper be used as a decorator (`@run`) without
  rebinding the decorated name to None.
  """
  banner_name = f.__name__
  print("\nTEST:", banner_name)
  f()
  return f
# CHECK-LABEL: TEST: testSubViewAccessors
+ at run
def testSubViewAccessors():
ctx = Context()
module = Module.parse(
@@ -52,4 +54,20 @@ def testSubViewAccessors():
print(subview.strides[1])
-run(testSubViewAccessors)
# CHECK-LABEL: TEST: testCustomBuilders
@run
def testCustomBuilders():
  """Exercises the custom memref.LoadOp builder, which infers its result type."""
  with Context() as ctx, Location.unknown(ctx):
    module = Module.parse(r"""
      func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
        return
      }
    """)
    func = module.body.operations[0]
    func_body = func.regions[0].blocks[0]
    with InsertionPoint.at_block_terminator(func_body):
      load = memref.LoadOp(func.arguments[0], func.arguments[1:])
      # The builder must infer the element type, not the memref type. Anchor
      # the match so it cannot hit the "f32" inside "memref<?x?xf32>".
      # CHECK: {{^}}f32{{$}}
      print(load.results[0].type)

    # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
    # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
    print(module)
More information about the Mlir-commits mailing list