[Mlir-commits] [mlir] 2f367f3 - [mlir][Linalg] Allow calling named ops when available and make it the default.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Mar 29 06:27:42 PDT 2021


Author: Nicolas Vasilache
Date: 2021-03-29T13:23:11Z
New Revision: 2f367f34fdeb1b185f1e39c192a77e19bfa3f16c

URL: https://github.com/llvm/llvm-project/commit/2f367f34fdeb1b185f1e39c192a77e19bfa3f16c
DIFF: https://github.com/llvm/llvm-project/commit/2f367f34fdeb1b185f1e39c192a77e19bfa3f16c.diff

LOG: [mlir][Linalg] Allow calling named ops when available and make it the default.

Differential Revision: https://reviews.llvm.org/D99419

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
    mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py
    mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/test/Bindings/Python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
index d35d10cc4b8eb..d787943d16372 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
@@ -49,9 +49,6 @@ class StructuredOpMixin:
   """All structured ops use the same mixin class."""
 
   def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
-    if outputs and results:
-      raise ValueError(
-          "Structured ops must have outputs or results, but not both.")
     super().__init__(
         self.build_generic(results=list(results),
                            operands=[list(inputs), list(outputs)],

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py
index 81949b8f881cb..9767183371119 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py
@@ -2,4 +2,52 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+# These are the backing OpView classes generated from the linalg tablegen
+# definitions following these steps:
+#   DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
 from .._linalg_ops_gen import *
+
+# These are the ground truth functions defined as:
+# ```
+#    @linalg_structured_op
+#    def matmul(A=TensorDef(T1, S.M, S.K),
+#               B=TensorDef(T2, S.K, S.N),
+#               C=TensorDef(U, S.M, S.N, output=True)):
+# ```
+# using the linalg-py eDSL.
+# The linalg-py eDSL builds a python representation (PyRepr) that is
+# used in the following ways:
+#  1. PyRepr -> YAML to generate the C++ and Python .td files. These
+#     then turn into the core C++ Op classes and Python OpView classes
+#     respectively (made available in _linalg_ops_gen). The generic OpView class 
+#     mechanism makes the C++ classes available to python through the CAPI.
+#     PyRepr -> YAML currently occurs before compiler compile time.
+#     The other steps in this category occur at compiler compile time.
+#  2. PyRepr -> linalg.core_named_ops calls: piggybacks on the 
+#     _linalg_ops_gen classes and the OpView mechanism to build IR at
+#     runtime in python:
+#       a. by default, the Named Op Form is emitted, e.g.:
+#          `linalg.matmul(lhs, rhs, outs=[out])` creates the following IR:
+#          ```
+#             %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) 
+#                               outs(%0 : tensor<4x8xf32>)
+#                  -> tensor<4x8xf32>   
+#          ```
+#       b. by setting emit_generic=True, the Generic Op Form is emitted, e.g.:
+#           `linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)` creates the following IR:
+#          ```
+#             %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]} 
+#               ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) 
+#              outs(%0 : tensor<4x8xf32>) {
+#               ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  
+#                  ...
+#                  linalg.yield %3 : f32
+#             } -> tensor<4x8xf32>  
+#          ```
+#  3. PyRepr -> Runtime Custom Op definitions: directly generates a
+#     linalg.generic form like in 2.b.
+#     !!!WARNING!!!: if one creates a runtime custom op with the same name 
+#     as an existing core named op, step 2. will likely take precedence.
+#     TODO: guard against surprises and fail to create Runtime Custom Ops with
+#     the same name as existing Core Named Ops.
+from .opdsl.ops.core_named_ops import *

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 6bc6ff97987a7..85da3323cac6d 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -359,16 +359,16 @@ class OpMetadataDef(YAMLObject):
   """Metadata about the op (generally not behavior impacting)."""
   yaml_tag = "!LinalgOpMetadata"
 
-  def __init__(self, name: str, cpp_op_name: Optional[str], doc: Optional[str]):
+  def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
     self.name = name
-    self.cpp_op_name = cpp_op_name if cpp_op_name is not None else name
+    self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
     self.doc = doc
     self.implements = []  # type: List[OpInterfaceDef]
 
   def to_yaml_custom_dict(self):
     d = dict(
         name=self.name,
-        cpp_op_name=self.cpp_op_name,
+        cpp_class_name=self.cpp_class_name,
         doc=self.doc,
     )
     if self.implements:
@@ -381,9 +381,9 @@ class LinalgOpDef:
 
   def __init__(self,
                name: str,
-               cpp_op_name: Optional[str] = None,
+               cpp_class_name: Optional[str] = None,
                doc: Optional[str] = None):
-    self.metadata = OpMetadataDef(name=name, cpp_op_name=cpp_op_name, doc=doc)
+    self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
     self.registered_tensors = dict()  # type: Dict[str, TensorDef]
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()
@@ -413,7 +413,7 @@ def tensor(self, name):
 
   def __repr__(self):
     lines = [
-        f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_op_name},"
+        f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
     ]
     for name, tensor in self.registered_tensors.items():
       lines.append(f"  {tensor}")

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
index cbff41db2d889..d6dc9895f89a8 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -44,7 +44,7 @@ def __init__(self, op_name: str, model: LinalgOpDef):
     self.op_name = op_name
     self.model = model
 
-  def __call__(self, *args, emit_generic: bool = True, **kwargs):
+  def __call__(self, *args, emit_generic: bool = False, **kwargs):
     """Emits the corresponding op definition as IR.
 
     Most arguments are passed through to the underlying emitter. The following
@@ -61,14 +61,21 @@ def __call__(self, *args, emit_generic: bool = True, **kwargs):
       raise NotImplementedError(
           f"Emission of composite linalg ops not supported: {op_configs}")
 
+    # TODO: this file should probably not be called dsl.py, since it is really a
+    # client of the DSL rather than the DSL itself.
+    from .... import linalg as linalg_ops
+    emit_generic = (emit_generic or 
+      (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys()))
+
     op_config = op_configs[0]
     if op_config.structured_op:
       if emit_generic:
         return emit_generic_structured_op(op_config.structured_op, *args,
                                           **kwargs)
       else:
-        return emit_named_structured_op(op_config.structured_op, *args,
-                                        **kwargs)
+        return emit_named_structured_op(
+          op_config.structured_op, self.op_name,
+          self.model.metadata.cpp_class_name, *args, **kwargs)
 
     raise NotImplementedError(
         f"Emission of linalg op type not supported: {op_config}")
@@ -91,7 +98,7 @@ def linalg_structured_op(dsl_func=None,
     op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
 
   tc_model = LinalgOpDef(name=op_name,
-                         cpp_op_name=op_class_name,
+                         cpp_class_name=op_class_name,
                          doc=inspect.getdoc(dsl_func))
 
   # Extract arguments and TensorDefs from the signature.

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 9a18993e9f627..e8e7eb5c3463e 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -17,9 +17,9 @@
 ]
 
 
-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
-                               *ins: Value,
-                               outs: Value = ()):
+def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
+                                 *ins: Value,
+                                 outs: Value):
   all_arg_defs = op_config.ordered_tensor_args
   in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
   out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
@@ -49,6 +49,18 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
       [AffineMapAttr.get(am) for am in op_config.indexing_maps])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
+
+  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types,
+          type_mapping, indexing_maps_attr, iterator_types_attr)
+
+
+def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
+                               *ins: Value,
+                               outs: Value = ()):
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr =   \
+     prepare_common_structured_op(op_config, *ins, outs = outs)
+
   generic_op = linalg.GenericOp(
       result_tensors=out_types,
       inputs=ins,
@@ -77,10 +89,23 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
 
 
 def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
+                             op_name: str,
+                             op_class_name: str,
                              *ins: Value,
                              outs: Value = ()):
-  raise NotImplementedError(
-      f"Emission of named structured ops is not supported: {op_config}")
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr =   \
+     prepare_common_structured_op(op_config, *ins, outs = outs)
+
+  if not op_class_name in linalg.__dict__.keys():
+    raise NotImplementedError(
+        f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
+
+  named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+  if len(out_arg_defs) == 1:
+    return named_op.result
+  else:
+    return named_op.results
 
 
 class _BodyBuilder:

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
index 22ed09e0716ce..8f2eb06004cee 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -75,8 +75,33 @@ def testStructuredOpOnBuffers():
                                 results=[]))
       with InsertionPoint(func.add_entry_block()):
         lhs, rhs, result = func.entry_block.arguments
+        # TODO: properly hook up the region.
         linalg.MatmulOp([lhs, rhs], outputs=[result])
         std.ReturnOp([])
 
   # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
   print(module)
+
+# CHECK-LABEL: TEST: testNamedStructuredOp
+@run
+def testNamedStructuredOp():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+                                   RankedTensorType.get((16, 8), f32))
+      def named_form(lhs, rhs):
+        init_result = linalg.InitTensorOp([4, 8], f32)
+        # CHECK: linalg.matmul
+        # TODO: properly hook up the region.
+        return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+                                   RankedTensorType.get((16, 8), f32))
+      def generic_form(lhs, rhs):
+        init_result = linalg.InitTensorOp([4, 8], f32)
+        # CHECK: linalg.generic
+        return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
+
+  print(module)


        


More information about the Mlir-commits mailing list