[llvm-branch-commits] [mlir] 894d88a - [mlir][python] Add facility for extending generated python ODS.
Stella Laurenzo via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 19 13:25:34 PST 2021
Author: Stella Laurenzo
Date: 2021-01-19T13:20:26-08:00
New Revision: 894d88a759c9376de4a48ed99c965aac97839b6c
URL: https://github.com/llvm/llvm-project/commit/894d88a759c9376de4a48ed99c965aac97839b6c
DIFF: https://github.com/llvm/llvm-project/commit/894d88a759c9376de4a48ed99c965aac97839b6c.diff
LOG: [mlir][python] Add facility for extending generated python ODS.
* This isn't mutually exclusive with other mechanisms for more ODS-centric op definitions, but based on discussions, we feel that we will always benefit from a Python escape hatch, and that it is the most natural way to write things that don't fit the mold.
* I suspect this facility needs further tweaking, and once it settles, I'll document it and add more tests.
* Added extensions for linalg, since it is unusable without them, and continued to evolve my end-to-end (e2e) example.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D94752
Added:
mlir/examples/python/.style.yapf
mlir/lib/Bindings/Python/.style.yapf
mlir/lib/Bindings/Python/mlir/dialects/_linalg.py
Modified:
mlir/examples/python/linalg_matmul.py
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/mlir/dialects/__init__.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/examples/python/.style.yapf b/mlir/examples/python/.style.yapf
new file mode 100644
index 000000000000..9ef1dc15ba62
--- /dev/null
+++ b/mlir/examples/python/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+ based_on_style = google
+ column_limit = 80
+ indent_width = 2
diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py
index 83dc15eda9b6..e9be189bfaaf 100644
--- a/mlir/examples/python/linalg_matmul.py
+++ b/mlir/examples/python/linalg_matmul.py
@@ -15,59 +15,69 @@
# TODO: This should be in the core API.
def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
- """Creates a |func| op.
+ """Creates a |func| op.
TODO: This should really be in the MLIR API.
Returns:
(operation, entry_block)
"""
- attrs = {
- "type": TypeAttr.get(func_type),
- "sym_name": StringAttr.get(name),
- }
- op = Operation.create("func", regions=1, attributes=attrs)
- body_region = op.regions[0]
- entry_block = body_region.blocks.append(*func_type.inputs)
- return op, entry_block
+ attrs = {
+ "type": TypeAttr.get(func_type),
+ "sym_name": StringAttr.get(name),
+ }
+ op = Operation.create("func", regions=1, attributes=attrs)
+ body_region = op.regions[0]
+ entry_block = body_region.blocks.append(*func_type.inputs)
+ return op, entry_block
-# TODO: Generate customs builder vs patching one in.
-def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None):
- super(linalg.MatmulOp, self).__init__(
- self._ods_build_default(operands=[[lhs, rhs], [result]],
- results=[],
- loc=loc,
- ip=ip))
+def build_matmul_buffers_func(func_name, m, k, n, dtype):
+ lhs_type = MemRefType.get(dtype, [m, k])
+ rhs_type = MemRefType.get(dtype, [k, n])
+ result_type = MemRefType.get(dtype, [m, n])
+ # TODO: There should be a one-liner for this.
+ func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
+ _, entry = FuncOp(func_name, func_type)
+ lhs, rhs, result = entry.arguments
+ with InsertionPoint(entry):
+ op = linalg.MatmulOp([lhs, rhs], [result])
# TODO: Implement support for SingleBlockImplicitTerminator
- block = self.regions[0].blocks.append()
+ block = op.regions[0].blocks.append()
with InsertionPoint(block):
linalg.YieldOp(values=[])
-linalg.MatmulOp.__init__ = PatchMatmulOpInit
+ std.ReturnOp([])
-def build_matmul_func(func_name, m, k, n, dtype):
- lhs_type = MemRefType.get(dtype, [m, k])
- rhs_type = MemRefType.get(dtype, [k, n])
- result_type = MemRefType.get(dtype, [m, n])
- # TODO: There should be a one-liner for this.
- func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
- _, entry = FuncOp(func_name, func_type)
- lhs, rhs, result = entry.arguments
- with InsertionPoint(entry):
- linalg.MatmulOp(lhs, rhs, result)
- std.ReturnOp([])
+def build_matmul_tensors_func(func_name, m, k, n, dtype):
+  """Builds a func that multiplies two ranked tensors with linalg.matmul.
+
+  The emitted func takes (m, k) and (k, n) tensors of `dtype` and returns the
+  (m, n) tensor produced by the matmul op.
+  """
+  # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
+  # from each other.
+  lhs_type, rhs_type, result_type = (RankedTensorType.get(shape, dtype)
+                                     for shape in ([m, k], [k, n], [m, n]))
+  # TODO: There should be a one-liner for this.
+  _, entry_block = FuncOp(
+      func_name, FunctionType.get([lhs_type, rhs_type], [result_type]))
+  lhs, rhs = entry_block.arguments
+  with InsertionPoint(entry_block):
+    matmul = linalg.MatmulOp([lhs, rhs], results=[result_type])
+    # TODO: Implement support for SingleBlockImplicitTerminator
+    region_block = matmul.regions[0].blocks.append()
+    with InsertionPoint(region_block):
+      linalg.YieldOp(values=[])
+    std.ReturnOp([matmul.result])
def run():
- with Context() as c, Location.unknown():
- module = Module.create()
- # TODO: This at_block_terminator vs default construct distinction feels
- # wrong and is error-prone.
- with InsertionPoint.at_block_terminator(module.body):
- build_matmul_func('main', 18, 32, 96, F32Type.get())
+ with Context() as c, Location.unknown():
+ module = Module.create()
+ # TODO: This at_block_terminator vs default construct distinction feels
+ # wrong and is error-prone.
+ with InsertionPoint.at_block_terminator(module.body):
+ build_matmul_buffers_func('main_buffers', 18, 32, 96, F32Type.get())
+ build_matmul_tensors_func('main_tensors', 18, 32, 96, F32Type.get())
- print(module)
- print(module.operation.get_asm(print_generic_op_form=True))
+ print(module)
-if __name__ == '__main__': run()
+if __name__ == '__main__':
+ run()
diff --git a/mlir/lib/Bindings/Python/.style.yapf b/mlir/lib/Bindings/Python/.style.yapf
new file mode 100644
index 000000000000..9ef1dc15ba62
--- /dev/null
+++ b/mlir/lib/Bindings/Python/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+ based_on_style = google
+ column_limit = 80
+ indent_width = 2
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 827348913744..1749ea2e5472 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -10,6 +10,7 @@ set(PY_SRC_FILES
mlir/_dlloader.py
mlir/ir.py
mlir/dialects/__init__.py
+ mlir/dialects/_linalg.py
mlir/ir.py
mlir/passmanager.py
mlir/transforms/__init__.py
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
index 56398ea5b64a..9c003b415438 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
@@ -5,7 +5,68 @@
# Re-export the parent _cext so that every level of the API can get it locally.
from .. import _cext
-def _segmented_accessor(elements, raw_segments, idx):
+# Public helper API of this package. The mlir-tblgen generated dialect modules
+# import these names explicitly (under `_ods_`-prefixed aliases).
+__all__ = [
+ "equally_sized_accessor",
+ "extend_opview_class",
+ "get_default_loc_context",
+ "segmented_accessor",
+]
+
+
+def extend_opview_class(ext_module):
+  """Decorator to extend an OpView class from an extension module.
+
+  Extension modules can expose various entry-points:
+    def select_opview_mixin(parent_opview_cls):
+      If defined, allows an appropriate mixin class to be selected dynamically
+      based on the parent OpView class. Should return NotImplemented (None is
+      also accepted) if a decision is not made.
+
+    Stand-alone class with the same name as a parent OpView class (i.e.
+    "ReturnOp"). Used as the mixin when `select_opview_mixin` is not defined.
+
+  Args:
+    ext_module: A module from which to locate extensions. Can be None if not
+      available.
+
+  Returns:
+    A decorator that takes an OpView subclass and returns either that class
+    unchanged or a subclass that combines the selected mixin with it.
+
+  Raises:
+    TypeError: (from the returned decorator) if the mixin cannot be combined
+      with the OpView class.
+  """
+
+  def class_decorator(parent_opview_cls: type):
+    if ext_module is None:
+      return parent_opview_cls
+    try:
+      select_mixin = getattr(ext_module, "select_opview_mixin")
+    except AttributeError:
+      # No selector hook: a stand-alone class named after the OpView, if the
+      # extension module defines one, is itself the mixin.
+      mixin_cls = getattr(ext_module, parent_opview_cls.__name__,
+                          NotImplemented)
+    else:
+      mixin_cls = select_mixin(parent_opview_cls)
+    if mixin_cls is NotImplemented or mixin_cls is None:
+      return parent_opview_cls
+
+    # Have a mixin_cls. List it first in the bases so its attributes (e.g.
+    # __init__) take precedence over those of the generated class.
+    try:
+
+      class LocalOpView(mixin_cls, parent_opview_cls):
+        pass
+    except TypeError as e:
+      raise TypeError(
+          f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
+    # Present the subclass under the generated class's identity.
+    LocalOpView.__name__ = parent_opview_cls.__name__
+    LocalOpView.__qualname__ = parent_opview_cls.__qualname__
+    LocalOpView.__module__ = parent_opview_cls.__module__
+    return LocalOpView
+
+  return class_decorator
+
+
+def segmented_accessor(elements, raw_segments, idx):
"""
Returns a slice of elements corresponding to the idx-th segment.
@@ -20,8 +81,8 @@ def _segmented_accessor(elements, raw_segments, idx):
return elements[start:end]
-def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
- n_preceding_variadic):
+def equally_sized_accessor(elements, n_variadic, n_preceding_simple,
+ n_preceding_variadic):
"""
Returns a starting position and a number of elements per variadic group
assuming equally-sized groups and the given numbers of preceding groups.
@@ -42,7 +103,8 @@ def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
start = n_preceding_simple + n_preceding_variadic * elements_per_group
return start, elements_per_group
-def _get_default_loc_context(location = None):
+
+def get_default_loc_context(location=None):
"""
Returns a context in which the defaulted location is created. If the location
is None, takes the current location from the stack, raises ValueError if there
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py
new file mode 100644
index 000000000000..574098a65567
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py
@@ -0,0 +1,28 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+class StructuredOpMixin:
+  """Common `__init__` mixin shared by all linalg structured op classes.
+
+  Builds the operation from a group of `inputs` plus either `outputs` (e.g.
+  memref destinations) or `results` (e.g. result tensor types).
+  """
+
+  def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
+    if outputs and results:
+      raise ValueError(
+          "Structured ops must have outputs or results, but not both.")
+    # Operands are passed as two groups: inputs first, then outputs.
+    # NOTE(review): presumably `_ods_build_default` (from the generated OpView)
+    # maps these groups onto the op's variadic operand segments.
+    operand_groups = [list(inputs), list(outputs)]
+    result_types = list(results)
+    built = self._ods_build_default(operands=operand_groups,
+                                    results=result_types,
+                                    loc=loc,
+                                    ip=ip)
+    super().__init__(built)
+
+
+def select_opview_mixin(parent_opview_cls):
+  """Selects StructuredOpMixin for generated linalg structured op classes.
+
+  Returns:
+    StructuredOpMixin for classes that look like structured ops, otherwise
+    NotImplemented (the "no decision" value expected by extend_opview_class).
+  """
+  # TODO: This shouldn't be a heuristic: we should have a way to annotate
+  # the OpView to note that it is a structured op.
+  # Classes that already define their own __init__ are skipped so the mixin's
+  # __init__ (which would come first in the MRO) does not shadow it.
+  if ("__init__" not in parent_opview_cls.__dict__ and
+      hasattr(parent_opview_cls, "inputs") and
+      hasattr(parent_opview_cls, "outputs") and
+      hasattr(parent_opview_cls, "result_tensors")):
+    return StructuredOpMixin
+  return NotImplemented
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 658ad75eea28..0197bfb15577 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -23,12 +23,19 @@ using namespace mlir;
using namespace mlir::tblgen;
/// File header and includes.
+/// {0} is the dialect namespace.
+/// The emitted module imports the shared helpers from the package's
+/// `__init__.py`, then tries to import an optional `_{0}` extension module,
+/// binding `_ods_ext_module` to None if that import raises ImportError.
+/// NOTE(review): catching ImportError also hides import failures raised from
+/// inside an existing `_{0}` module (e.g. a broken import there), which then
+/// silently disables all extensions; consider narrowing this.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from . import _cext as _ods_cext
-from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context
+from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir
+
+try:
+ from . import _{0} as _ods_ext_module
+except ImportError:
+ _ods_ext_module = None
+
)Py";
/// Template for dialect class:
@@ -46,6 +53,7 @@ class _Dialect(_ods_ir.Dialect):
/// {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
+@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
@@ -706,7 +714,7 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
AttributeClasses attributeClasses;
constructAttributeMapping(records, attributeClasses);
- os << fileHeader;
+ os << llvm::formatv(fileHeader, clDialectName.getValue());
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
More information about the llvm-branch-commits
mailing list