[llvm-branch-commits] [mlir] 894d88a - [mlir][python] Add facility for extending generated python ODS.

Stella Laurenzo via llvm-branch-commits llvm-branch-commits@lists.llvm.org
Tue Jan 19 13:25:34 PST 2021


Author: Stella Laurenzo
Date: 2021-01-19T13:20:26-08:00
New Revision: 894d88a759c9376de4a48ed99c965aac97839b6c

URL: https://github.com/llvm/llvm-project/commit/894d88a759c9376de4a48ed99c965aac97839b6c
DIFF: https://github.com/llvm/llvm-project/commit/894d88a759c9376de4a48ed99c965aac97839b6c.diff

LOG: [mlir][python] Add facility for extending generated python ODS.

* This isn't exclusive of other mechanisms for more ODS-centric op definitions, but based on discussions, we feel that we will always benefit from a Python escape hatch, which is the most natural way to write things that don't fit the mold (a minimal sketch of the mechanism follows the review metadata below).
* I suspect this facility needs further tweaking; once it settles, I'll document it and add more tests.
* Added extensions for linalg, since it is unusable without them, and continued to evolve my e2e example.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D94752
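
As a concrete illustration of the escape hatch (a sketch, not part of this
commit; the _mydialect module name and ConstantOp mixin are hypothetical), an
extension module placed next to a generated dialect module could look like:

  # _mydialect.py (hypothetical), picked up by the generated mydialect.py via
  # "from . import _mydialect as _ods_ext_module".

  class ConstantOp:
    """Stand-alone mixin merged into the generated ConstantOp."""

    @property
    def literal_value(self):
      # Mixin methods execute with the full generated OpView as `self`.
      return self.operation.attributes["value"]

Alternatively, the module can define select_opview_mixin(parent_opview_cls) to
pick a mixin dynamically, as the linalg extension in this commit does.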

Added: 
    mlir/examples/python/.style.yapf
    mlir/lib/Bindings/Python/.style.yapf
    mlir/lib/Bindings/Python/mlir/dialects/_linalg.py

Modified: 
    mlir/examples/python/linalg_matmul.py
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/mlir/dialects/__init__.py
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/examples/python/.style.yapf b/mlir/examples/python/.style.yapf
new file mode 100644
index 000000000000..9ef1dc15ba62
--- /dev/null
+++ b/mlir/examples/python/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+  based_on_style = google
+  column_limit = 80
+  indent_width = 2

diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py
index 83dc15eda9b6..e9be189bfaaf 100644
--- a/mlir/examples/python/linalg_matmul.py
+++ b/mlir/examples/python/linalg_matmul.py
@@ -15,59 +15,69 @@
 
 # TODO: This should be in the core API.
 def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
-    """Creates a |func| op.
+  """Creates a |func| op.
     TODO: This should really be in the MLIR API.
     Returns:
       (operation, entry_block)
     """
-    attrs = {
-        "type": TypeAttr.get(func_type),
-        "sym_name": StringAttr.get(name),
-    }
-    op = Operation.create("func", regions=1, attributes=attrs)
-    body_region = op.regions[0]
-    entry_block = body_region.blocks.append(*func_type.inputs)
-    return op, entry_block
+  attrs = {
+      "type": TypeAttr.get(func_type),
+      "sym_name": StringAttr.get(name),
+  }
+  op = Operation.create("func", regions=1, attributes=attrs)
+  body_region = op.regions[0]
+  entry_block = body_region.blocks.append(*func_type.inputs)
+  return op, entry_block
 
 
-# TODO: Generate customs builder vs patching one in.
-def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None):
-    super(linalg.MatmulOp, self).__init__(
-        self._ods_build_default(operands=[[lhs, rhs], [result]],
-                                results=[],
-                                loc=loc,
-                                ip=ip))
+def build_matmul_buffers_func(func_name, m, k, n, dtype):
+  lhs_type = MemRefType.get(dtype, [m, k])
+  rhs_type = MemRefType.get(dtype, [k, n])
+  result_type = MemRefType.get(dtype, [m, n])
+  # TODO: There should be a one-liner for this.
+  func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
+  _, entry = FuncOp(func_name, func_type)
+  lhs, rhs, result = entry.arguments
+  with InsertionPoint(entry):
+    op = linalg.MatmulOp([lhs, rhs], [result])
     # TODO: Implement support for SingleBlockImplicitTerminator
-    block = self.regions[0].blocks.append()
+    block = op.regions[0].blocks.append()
     with InsertionPoint(block):
         linalg.YieldOp(values=[])
 
-linalg.MatmulOp.__init__ = PatchMatmulOpInit
+    std.ReturnOp([])
 
 
-def build_matmul_func(func_name, m, k, n, dtype):
-    lhs_type = MemRefType.get(dtype, [m, k])
-    rhs_type = MemRefType.get(dtype, [k, n])
-    result_type = MemRefType.get(dtype, [m, n])
-    # TODO: There should be a one-liner for this.
-    func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
-    _, entry = FuncOp(func_name, func_type)
-    lhs, rhs, result = entry.arguments
-    with InsertionPoint(entry):
-        linalg.MatmulOp(lhs, rhs, result)
-        std.ReturnOp([])
+def build_matmul_tensors_func(func_name, m, k, n, dtype):
+  # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
+  # from each other.
+  lhs_type = RankedTensorType.get([m, k], dtype)
+  rhs_type = RankedTensorType.get([k, n], dtype)
+  result_type = RankedTensorType.get([m, n], dtype)
+  # TODO: There should be a one-liner for this.
+  func_type = FunctionType.get([lhs_type, rhs_type], [result_type])
+  _, entry = FuncOp(func_name, func_type)
+  lhs, rhs = entry.arguments
+  with InsertionPoint(entry):
+    op = linalg.MatmulOp([lhs, rhs], results=[result_type])
+    # TODO: Implement support for SingleBlockImplicitTerminator
+    block = op.regions[0].blocks.append()
+    with InsertionPoint(block):
+        linalg.YieldOp(values=[])
+    std.ReturnOp([op.result])
 
 
 def run():
-    with Context() as c, Location.unknown():
-        module = Module.create()
-        # TODO: This at_block_terminator vs default construct distinction feels
-        # wrong and is error-prone.
-        with InsertionPoint.at_block_terminator(module.body):
-            build_matmul_func('main', 18, 32, 96, F32Type.get())
+  with Context() as c, Location.unknown():
+    module = Module.create()
+    # TODO: This at_block_terminator vs default construct distinction feels
+    # wrong and is error-prone.
+    with InsertionPoint.at_block_terminator(module.body):
+      build_matmul_buffers_func('main_buffers', 18, 32, 96, F32Type.get())
+      build_matmul_tensors_func('main_tensors', 18, 32, 96, F32Type.get())
 
-        print(module)
-        print(module.operation.get_asm(print_generic_op_form=True))
+    print(module)
 
 
-if __name__ == '__main__': run()
+if __name__ == '__main__':
+  run()

diff --git a/mlir/lib/Bindings/Python/.style.yapf b/mlir/lib/Bindings/Python/.style.yapf
new file mode 100644
index 000000000000..9ef1dc15ba62
--- /dev/null
+++ b/mlir/lib/Bindings/Python/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+  based_on_style = google
+  column_limit = 80
+  indent_width = 2

diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 827348913744..1749ea2e5472 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -10,6 +10,7 @@ set(PY_SRC_FILES
   mlir/_dlloader.py
   mlir/ir.py
   mlir/dialects/__init__.py
+  mlir/dialects/_linalg.py
   mlir/ir.py
   mlir/passmanager.py
   mlir/transforms/__init__.py

diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
index 56398ea5b64a..9c003b415438 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
@@ -5,7 +5,68 @@
 # Re-export the parent _cext so that every level of the API can get it locally.
 from .. import _cext
 
-def _segmented_accessor(elements, raw_segments, idx):
+__all__ = [
+    "equally_sized_accessor",
+    "extend_opview_class",
+    "get_default_loc_context",
+    "segmented_accessor",
+]
+
+
+def extend_opview_class(ext_module):
+  """Decorator to extend an OpView class from an extension module.
+
+  Extension modules can expose various entry-points:
+    def select_opview_mixin(parent_opview_cls):
+      If defined, allows an appropriate mixin class to be selected dynamically
+      based on the parent OpView class. Should return NotImplemented if a
+      decision is not made.
+
+    Stand-alone class with the same name as a parent OpView class (i.e.
+    "ReturnOp").
+
+  Args:
+    ext_module: A module from which to locate extensions. Can be None if not
+      available.
+
+  Returns:
+    A decorator that takes an OpView subclass and further extends it as
+    needed.
+  """
+
+  def class_decorator(parent_opview_cls: type):
+    if ext_module is None:
+      return parent_opview_cls
+    mixin_cls = NotImplemented
+    try:
+      select_mixin = getattr(ext_module, "select_opview_mixin")
+    except AttributeError:
+      # Fall back to a same-named stand-alone class, used directly as the mixin.
+      try:
+        mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
+      except AttributeError:
+        pass
+    else:
+      mixin_cls = select_mixin(parent_opview_cls)
+    if mixin_cls is NotImplemented or mixin_cls is None:
+      return parent_opview_cls
+
+    # Have a mixin_cls. Create an appropriate subclass.
+    try:
+
+      class LocalOpView(mixin_cls, parent_opview_cls):
+        pass
+    except TypeError as e:
+      raise TypeError(
+          f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
+    LocalOpView.__name__ = parent_opview_cls.__name__
+    LocalOpView.__qualname__ = parent_opview_cls.__qualname__
+    return LocalOpView
+
+  return class_decorator
+
+
+def segmented_accessor(elements, raw_segments, idx):
   """
   Returns a slice of elements corresponding to the idx-th segment.
 
@@ -20,8 +81,8 @@ def _segmented_accessor(elements, raw_segments, idx):
   return elements[start:end]
 
 
-def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
-                            n_preceding_variadic):
+def equally_sized_accessor(elements, n_variadic, n_preceding_simple,
+                           n_preceding_variadic):
   """
   Returns a starting position and a number of elements per variadic group
   assuming equally-sized groups and the given numbers of preceding groups.
@@ -42,7 +103,8 @@ def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
   start = n_preceding_simple + n_preceding_variadic * elements_per_group
   return start, elements_per_group
 
-def _get_default_loc_context(location = None):
+
+def get_default_loc_context(location=None):
   """
   Returns a context in which the defaulted location is created. If the location
   is None, takes the current location from the stack, raises ValueError if there
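
The decorator itself has no MLIR dependency, so its behavior can be checked
with plain classes standing in for generated OpViews (a self-contained sketch,
assuming the mlir Python bindings are importable; all other names below are
made up for illustration):

  import types
  from mlir.dialects import extend_opview_class

  # Fake extension module exposing select_opview_mixin, as _linalg.py does.
  ext = types.ModuleType("_fake_ext")

  class _Mixin:
    def describe(self):
      return "extended " + type(self).OPERATION_NAME

  ext.select_opview_mixin = lambda parent_opview_cls: _Mixin

  @extend_opview_class(ext)
  class FakeOp:  # Stands in for a generated OpView subclass.
    OPERATION_NAME = "test.fake"

  print(FakeOp().describe())  # -> extended test.fake
  print(FakeOp.__name__)      # -> FakeOp (the decorator preserves the name)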

diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py
new file mode 100644
index 000000000000..574098a65567
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py
@@ -0,0 +1,28 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+class StructuredOpMixin:
+  """All structured ops use the same mixin class."""
+
+  def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
+    if outputs and results:
+      raise ValueError(
+          "Structured ops must have outputs or results, but not both.")
+    super().__init__(
+        self._ods_build_default(operands=[list(inputs),
+                                          list(outputs)],
+                                results=list(results),
+                                loc=loc,
+                                ip=ip))
+
+
+def select_opview_mixin(parent_opview_cls):
+  # TODO: This shouldn't be a heuristic: we should have a way to annotate
+  # the OpView to note that it is a structured op.
+  if ("__init__" not in parent_opview_cls.__dict__ and
+      hasattr(parent_opview_cls, "inputs") and
+      hasattr(parent_opview_cls, "outputs") and
+      hasattr(parent_opview_cls, "result_tensors")):
+    return StructuredOpMixin
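
With this mixin, a structured op takes inputs plus either outputs (buffer form)
or results (tensor form), matching its use in linalg_matmul.py above. A
fragment reusing the names from that example:

  # Buffer form: outputs given, no results.
  op = linalg.MatmulOp([lhs, rhs], [result])

  # Tensor form: no outputs; result types are requested instead.
  op = linalg.MatmulOp([lhs, rhs], results=[result_type])

  # Supplying both outputs and results raises:
  #   ValueError: Structured ops must have outputs or results, but not both.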

diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 658ad75eea28..0197bfb15577 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -23,12 +23,19 @@ using namespace mlir;
 using namespace mlir::tblgen;
 
 /// File header and includes.
+///   {0} is the dialect namespace.
 constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from . import _cext as _ods_cext
-from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context
+from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
 _ods_ir = _ods_cext.ir
+
+try:
+  from . import _{0} as _ods_ext_module
+except ImportError:
+  _ods_ext_module = None
+
 )Py";
 
 /// Template for dialect class:
@@ -46,6 +53,7 @@ class _Dialect(_ods_ir.Dialect):
 ///   {1} is the operation name.
 constexpr const char *opClassTemplate = R"Py(
 @_ods_cext.register_operation(_Dialect)
+@_ods_extend_opview_class(_ods_ext_module)
 class {0}(_ods_ir.OpView):
   OPERATION_NAME = "{1}"
 )Py";
@@ -706,7 +714,7 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
   AttributeClasses attributeClasses;
   constructAttributeMapping(records, attributeClasses);
 
-  os << fileHeader;
+  os << llvm::formatv(fileHeader, clDialectName.getValue());
   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
     Operator op(rec);


        

