[Mlir-commits] [mlir] b164f23 - [mlir][python] support taking ops instead of values in op constructors

Alex Zinenko llvmlistbot at llvm.org
Fri Oct 8 00:49:54 PDT 2021


Author: Alex Zinenko
Date: 2021-10-08T09:49:48+02:00
New Revision: b164f23c29fdf8b1e82fc4cfeab79c9fb6df918d

URL: https://github.com/llvm/llvm-project/commit/b164f23c29fdf8b1e82fc4cfeab79c9fb6df918d
DIFF: https://github.com/llvm/llvm-project/commit/b164f23c29fdf8b1e82fc4cfeab79c9fb6df918d.diff

LOG: [mlir][python] support taking ops instead of values in op constructors

Introduce support for accepting ops instead of values when constructing ops. A
single-result op can be used instead of a value, including in lists of values,
and any op can be used instead of a list of values. This is similar to, but
more powerful than, the C++ API that allows for implicitly casting an OpType to
Value if it is statically known to have a single result - the cast in Python is
based on the op dynamically having a single result, and also handles the
multi-result case. This allows building IR in a more concise way:

    op = dialect.produce_multiple_results()
    other = dialect.produce_single_result()
    dialect.consume_multiple_results(other, op)

instead of having to access the results manually

    op = dialect.produce_multiple_results()
    other = dialect.produce_single_result()
    dialect.consume_multiple_results(other.result, op.operation.results)

The dispatch is implemented directly in Python and is triggered automatically
for autogenerated OpView subclasses. Extension OpView classes should use the
functions provided in ods_common.py if they want to implement this behavior.
An alternative could be to implement the dispatch in the C++ bindings code, but
it would require forwarding opaque types through all Python functions down to a
binding call, which makes it hard to inspect them in Python, e.g., to obtain
the types of values.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D111306

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_linalg_ops_ext.py
    mlir/python/mlir/dialects/_ods_common.py
    mlir/python/mlir/dialects/_scf_ops_ext.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/test/python/dialects/linalg/ops.py
    mlir/test/python/dialects/scf.py
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index 5360967492d5c..b7641c0a4b53c 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -10,6 +10,7 @@
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
+from ._ods_common import get_op_result_or_value as _get_op_result_or_value
 
 def isa(cls: Type, ty: Type):
   try:
@@ -26,11 +27,12 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
     results = []
     if isa(RankedTensorType, output.type):
       results = [output.type]
-    op = self.build_generic(results=results,
-                            operands=[value, output],
-                            attributes=None,
-                            loc=loc,
-                            ip=ip)
+    op = self.build_generic(
+        results=results,
+        operands=[_get_op_result_or_value(o) for o in [value, output]],
+        attributes=None,
+        loc=loc,
+        ip=ip)
     OpView.__init__(self, op)
     linalgDialect = Context.current.get_dialect_descriptor("linalg")
     fill_builtin_region(linalgDialect, self.operation)

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 2fbf3545f46d4..95c44186533f1 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -5,11 +5,14 @@
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
+from typing import Sequence as _Sequence, Union as _Union
 
 __all__ = [
     "equally_sized_accessor",
     "extend_opview_class",
     "get_default_loc_context",
+    "get_op_result_or_value",
+    "get_op_results_or_values",
     "segmented_accessor",
 ]
 
@@ -118,3 +121,38 @@ def get_default_loc_context(location=None):
     # Location.current raises ValueError if there is no current location.
     return _cext.ir.Location.current.context
   return location.context
+
+
+def get_op_result_or_value(
+    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
+) -> _cext.ir.Value:
+  """Returns the given value or the single result of the given op.
+
+  This is useful to implement op constructors so that they can take other ops as
+  arguments instead of requiring the caller to extract results for every op.
+  Raises ValueError if provided with an op that doesn't have a single result.
+  """
+  if isinstance(arg, _cext.ir.OpView):
+    return arg.operation.result
+  elif isinstance(arg, _cext.ir.Operation):
+    return arg.result
+  else:
+    assert isinstance(arg, _cext.ir.Value)
+    return arg
+
+
+def get_op_results_or_values(
+    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
+) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
+  """Returns the given sequence of values or the results of the given op.
+
+  This is useful to implement op constructors so that they can take other ops as
+  lists of arguments instead of requiring the caller to extract results for
+  every op.
+  """
+  if isinstance(arg, _cext.ir.OpView):
+    return arg.operation.results
+  elif isinstance(arg, _cext.ir.Operation):
+    return arg.results
+  else:
+    return arg

diff  --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index c6532a75632b5..a8924a7507a42 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -7,8 +7,8 @@
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Any, Sequence
-
+from typing import Any, Optional, Sequence, Union
+from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 
 class ForOp:
   """Specialization for the SCF for op class."""
@@ -17,7 +17,8 @@ def __init__(self,
                lower_bound,
                upper_bound,
                step,
-               iter_args: Sequence[Any] = [],
+               iter_args: Optional[Union[Operation, OpView,
+                                         Sequence[Value]]] = None,
                *,
                loc=None,
                ip=None):
@@ -26,14 +27,22 @@ def __init__(self,
     - `lower_bound` is the value to use as lower bound of the loop.
     - `upper_bound` is the value to use as upper bound of the loop.
     - `step` is the value to use as loop step.
-    - `iter_args` is a list of additional loop-carried arguments.
+    - `iter_args` is a list of additional loop-carried arguments or an operation
+      producing them as results.
     """
+    if iter_args is None:
+      iter_args = []
+    iter_args = _get_op_results_or_values(iter_args)
+
     results = [arg.type for arg in iter_args]
     super().__init__(
         self.build_generic(
             regions=1,
             results=results,
-            operands=[lower_bound, upper_bound, step] + list(iter_args),
+            operands=[
+                _get_op_result_or_value(o)
+                for o in [lower_bound, upper_bound, step]
+            ] + list(iter_args),
             loc=loc,
             ip=ip))
     self.regions[0].blocks.append(IndexType.get(), *results)

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 047bde245b645..1acae7a7a389a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Dict, List
+from typing import Dict, List, Sequence, Union
 
 from contextlib import contextmanager
 import functools
@@ -10,12 +10,15 @@
 import threading
 
 from ..... import ir
+from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 from .comprehension import *
 from .config import *
 from .emitter import *
 
 _CONTEXT = threading.local()
 
+StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
+                         Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
 
 @contextmanager
 def bind_op_def(model: LinalgOpDef):
@@ -37,6 +40,15 @@ def current_op_def() -> LinalgOpDef:
         "but none is set. Did you mean to call this in an op definition?")
 
 
+def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
+  if isinstance(outs, (ir.Operation, ir.OpView)):
+    return _get_op_results_or_values(outs)
+  elif isinstance(outs, ir.OpResultList):
+    return outs
+
+  return [_get_op_result_or_value(o) for o in outs]
+
+
 class DefinedOpCallable:
   """Callable that wraps any defined op function."""
 
@@ -44,7 +56,8 @@ def __init__(self, op_name: str, model: LinalgOpDef):
     self.op_name = op_name
     self.model = model
 
-  def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
+  def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
+               outs: StructuredOpOuts, **kwargs):
     """Emits the corresponding op definition as IR.
 
     Most arguments are passed through to the underlying emitter. The following
@@ -73,17 +86,19 @@ def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
         emit_generic or not ctx.is_registered_operation(fully_qualified_name))
 
     op_config = op_configs[0]
+    out_values = _prepare_structured_op_outs(outs)
+    in_values = [_get_op_result_or_value(i) for i in ins]
     if op_config.structured_op:
       if emit_generic:
         return emit_generic_structured_op(
-            op_config.structured_op, *ins, outs=outs, **kwargs)
+            op_config.structured_op, *in_values, outs=out_values, **kwargs)
       else:
         return emit_named_structured_op(
             op_config.structured_op,
             self.op_name,
             self.model.metadata.cpp_class_name,
-            *ins,
-            outs=outs,
+            *in_values,
+            outs=out_values,
             **kwargs)
 
     raise NotImplementedError(

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 7feea040aa77c..021fe83285945 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Dict, Sequence
+from typing import Dict, List, Sequence, Tuple, Union
 
 from .....ir import *
 from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
@@ -10,6 +10,7 @@
 from .... import linalg
 from .... import std
 from .... import math
+from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 
 from .scalar_expr import *
 from .config import *
@@ -18,8 +19,10 @@
 __all__ = [
     "emit_generic_structured_op",
     "emit_named_structured_op",
+    "ValueList",
 ]
 
+ValueList = Union[Sequence[Value], OpResultList]
 
 def isa(cls: Type, ty: Type):
   try:
@@ -30,17 +33,18 @@ def isa(cls: Type, ty: Type):
 
 
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
-                                 *ins: Value, outs: Sequence[Value],
+                                 *ins: Value, outs: ValueList,
                                  **attrs: Sequence[int]):
   all_arg_defs = op_config.ordered_operands
   in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
   out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
   attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
 
-  # Verify outs is a sequence.
-  if not isinstance(outs, Sequence):
-    raise ValueError(f"Expected named argument outs to have type Sequence "
-                     f"but got {type(outs)}")
+  # Verify outs is a sequence or a list of results.
+  if not isinstance(outs, (Sequence, OpResultList)):
+    raise ValueError(
+        f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
+    )
 
   # Arity validation.
   if len(ins) != len(in_arg_defs):
@@ -122,7 +126,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
-                               outs: Sequence[Value], **attrs: Sequence[int]):
+                               outs: ValueList, **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
   indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -153,8 +157,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
 
 
 def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
-                             op_class_name: str, *ins: Value,
-                             outs: Sequence[Value], **attrs: Sequence[int]):
+                             op_class_name: str, *ins: Value, outs: ValueList,
+                             **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
   indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -355,11 +359,11 @@ def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
       return std.MinUIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
 
-def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
-                           in_arg_defs: Sequence[OperandDefConfig],
-                           ins: Sequence[Value],
-                           out_arg_defs: Sequence[OperandDefConfig],
-                           outs: Sequence[Value]):
+def _infer_structured_outs(
+    op_config: LinalgStructuredOpConfig,
+    in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
+    out_arg_defs: Sequence[OperandDefConfig],
+    outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
   """Infers implicit outs and output types.
 
   Respects existing contents of outs if not empty.

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 572c657336686..c3ee0c47aa05a 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -24,9 +24,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
-  // CHECK:   operands.append(variadic1)
-  // CHECK:   operands.append(non_variadic)
-  // CHECK:   if variadic2 is not None: operands.append(variadic2)
+  // CHECK:   operands.append(_get_op_results_or_values(variadic1))
+  // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
+  // CHECK:   if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -150,8 +150,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
-  // CHECK:   operands.append(_gen_arg_0)
-  // CHECK:   operands.append(_gen_arg_2)
+  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_0))
+  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
   // CHECK:   if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
   // CHECK:   if is_ is not None: attributes["is"] = is_
@@ -197,9 +197,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:   results.append(i32)
   // CHECK:   results.append(_gen_res_1)
   // CHECK:   results.append(i64)
-  // CHECK:   operands.append(_gen_arg_0)
-  // CHECK:   operands.append(f32)
-  // CHECK:   operands.append(_gen_arg_2)
+  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_0))
+  // CHECK:   operands.append(_get_op_result_or_value(f32))
+  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -230,8 +230,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
-  // CHECK:   operands.append(non_variadic)
-  // CHECK:   operands.extend(variadic)
+  // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
+  // CHECK:   operands.extend(_get_op_results_or_values(variadic))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -285,7 +285,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
-  // CHECK:   operands.append(in_)
+  // CHECK:   operands.append(_get_op_result_or_value(in_))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -353,8 +353,8 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:   attributes = {}
   // CHECK:   results.append(i64)
   // CHECK:   results.append(f64)
-  // CHECK:   operands.append(i32)
-  // CHECK:   operands.append(f32)
+  // CHECK:   operands.append(_get_op_result_or_value(i32))
+  // CHECK:   operands.append(_get_op_result_or_value(f32))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 8b990e66e13ec..6f07969cce856 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -185,3 +185,30 @@ def generic_form(lhs, rhs):
         return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
 
   print(module)
+
+
+# CHECK-LABEL: TEST: testOpResultFromOtherOp
+ at run
+def testOpResultFromOtherOp():
+  with Context(), Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
+                                                                   f32))
+      def pass_an_op_directly(arg0, arg1):
+        one = std.ConstantOp(F32Type.get(), 1.0)
+        # CHECK: %[[LHS:.*]] = linalg.fill
+        lhs = linalg.FillOp(arg0, one)
+        # CHECK: %[[RHS:.*]] = linalg.fill
+        rhs = linalg.FillOp(arg1, one)
+        # CHECK: %[[INIT:.*]] = linalg.init_tensor
+        init = linalg.InitTensorOp([4, 8], f32)
+        # CHECK: linalg.matmul
+        # CHECK: ins(%[[LHS]], %[[RHS]]
+        # CHECK: outs(%[[INIT]]
+        return linalg.matmul(lhs, rhs, outs=init)
+
+  print(module)

diff  --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index d604913b1c4cb..7819679d90985 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -2,53 +2,82 @@
 
 from mlir.ir import *
 from mlir.dialects import scf
+from mlir.dialects import std
 from mlir.dialects import builtin
 
 
-def run(f):
+def constructAndPrintInModule(f):
   print("\nTEST:", f.__name__)
-  f()
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      f()
+    print(module)
   return f
 
 
 # CHECK-LABEL: TEST: testSimpleLoop
- at run
+ at constructAndPrintInModule
 def testSimpleLoop():
-  with Context(), Location.unknown():
-    module = Module.create()
-    index_type = IndexType.get()
-    with InsertionPoint(module.body):
+  index_type = IndexType.get()
 
-      @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
-      def simple_loop(lb, ub, step):
-        loop = scf.ForOp(lb, ub, step, [lb, lb])
-        with InsertionPoint(loop.body):
-          scf.YieldOp(loop.inner_iter_args)
-        return
+  @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+  def simple_loop(lb, ub, step):
+    loop = scf.ForOp(lb, ub, step, [lb, lb])
+    with InsertionPoint(loop.body):
+      scf.YieldOp(loop.inner_iter_args)
+    return
 
-  # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-  # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-  # CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
-  # CHECK: scf.yield %[[I1]], %[[I2]]
-  print(module)
+
+# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
+# CHECK: scf.yield %[[I1]], %[[I2]]
 
 
 # CHECK-LABEL: TEST: testInductionVar
- at run
+ at constructAndPrintInModule
 def testInductionVar():
-  with Context(), Location.unknown():
-    module = Module.create()
-    index_type = IndexType.get()
-    with InsertionPoint(module.body):
+  index_type = IndexType.get()
+
+  @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+  def induction_var(lb, ub, step):
+    loop = scf.ForOp(lb, ub, step, [lb])
+    with InsertionPoint(loop.body):
+      scf.YieldOp([loop.induction_variable])
+    return
+
+
+# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK: scf.yield %[[IV]]
+
+
+ at constructAndPrintInModule
+def testOpsAsArguments():
+  index_type = IndexType.get()
+  callee = builtin.FuncOp(
+      "callee", ([], [index_type, index_type]), visibility="private")
+  func = builtin.FuncOp("ops_as_arguments", ([], []))
+  with InsertionPoint(func.add_entry_block()):
+    lb = std.ConstantOp.create_index(0)
+    ub = std.ConstantOp.create_index(42)
+    step = std.ConstantOp.create_index(2)
+    iter_args = std.CallOp(callee, [])
+    loop = scf.ForOp(lb, ub, step, iter_args)
+    with InsertionPoint(loop.body):
+      scf.YieldOp(loop.inner_iter_args)
+    std.ReturnOp([])
+
 
-      @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
-      def induction_var(lb, ub, step):
-        loop = scf.ForOp(lb, ub, step, [lb])
-        with InsertionPoint(loop.body):
-          scf.YieldOp([loop.induction_variable])
-        return
-
-  # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-  # CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-  # CHECK: scf.yield %[[IV]]
-  print(module)
+# CHECK-LABEL: TEST: testOpsAsArguments
+# CHECK: func private @callee() -> (index, index)
+# CHECK: func @ops_as_arguments() {
+# CHECK:   %[[LB:.*]] = constant 0
+# CHECK:   %[[UB:.*]] = constant 42
+# CHECK:   %[[STEP:.*]] = constant 2
+# CHECK:   %[[ARGS:.*]]:2 = call @callee()
+# CHECK:   scf.for %arg0 = %c0 to %c42 step %c2
+# CHECK:   iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
+# CHECK:     scf.yield %{{.*}}, %{{.*}}
+# CHECK:   return

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 742cad748ca46..51993bbf6b051 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -28,7 +28,7 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 _ods_ir = _ods_cext.ir
 
 try:
@@ -489,20 +489,25 @@ constexpr const char *initTemplate = R"Py(
 )Py";
 
 /// Template for appending a single element to the operand/result list.
-///   {0} is either 'operand' or 'result';
-///   {1} is the field name.
-constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
+///   {0} is the field name.
+constexpr const char *singleOperandAppendTemplate =
+    "operands.append(_get_op_result_or_value({0}))";
+constexpr const char *singleResultAppendTemplate = "results.append({0})";
 
 /// Template for appending an optional element to the operand/result list.
-///   {0} is either 'operand' or 'result';
-///   {1} is the field name.
-constexpr const char *optionalAppendTemplate =
-    "if {1} is not None: {0}s.append({1})";
-
-/// Template for appending a a list of elements to the operand/result list.
-///   {0} is either 'operand' or 'result';
-///   {1} is the field name.
-constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
+///   {0} is the field name.
+constexpr const char *optionalAppendOperandTemplate =
+    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
+constexpr const char *optionalAppendResultTemplate =
+    "if {0} is not None: results.append({0})";
+
+/// Template for appending a list of elements to the operand/result list.
+///   {0} is the field name.
+constexpr const char *multiOperandAppendTemplate =
+    "operands.extend(_get_op_results_or_values({0}))";
+constexpr const char *multiOperandAppendPackTemplate =
+    "operands.append(_get_op_results_or_values({0}))";
+constexpr const char *multiResultAppendTemplate = "results.extend({0})";
 
 /// Template for setting an attribute in the operation builder.
 ///   {0} is the attribute name;
@@ -625,43 +630,70 @@ static void populateBuilderLinesSuccessors(
 }
 
 /// Populates `builderLines` with additional lines that are required in the
-/// builder. `kind` must be either "operand" or "result". `names` contains the
-/// names of init arguments that correspond to the elements.
-static void populateBuilderLines(
-    const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
-    llvm::SmallVectorImpl<std::string> &builderLines,
-    llvm::function_ref<int(const Operator &)> getNumElements,
-    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
-        getElement) {
-  bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
+/// builder to set up op operands.
+static void
+populateBuilderLinesOperand(const Operator &op,
+                            llvm::ArrayRef<std::string> names,
+                            llvm::SmallVectorImpl<std::string> &builderLines) {
+  bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
 
   // For each element, find or generate a name.
-  for (int i = 0, e = getNumElements(op); i < e; ++i) {
-    const NamedTypeConstraint &element = getElement(op, i);
+  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+    const NamedTypeConstraint &element = op.getOperand(i);
+    std::string name = names[i];
+
+    // Choose the formatting string based on the element kind.
+    llvm::StringRef formatString;
+    if (!element.isVariableLength()) {
+      formatString = singleOperandAppendTemplate;
+    } else if (element.isOptional()) {
+      formatString = optionalAppendOperandTemplate;
+    } else {
+      assert(element.isVariadic() && "unhandled element group type");
+      // If emitting with sizedSegments, then we add the actual list-typed
+      // element. Otherwise, we extend the actual operands.
+      if (sizedSegments) {
+        formatString = multiOperandAppendPackTemplate;
+      } else {
+        formatString = multiOperandAppendTemplate;
+      }
+    }
+
+    builderLines.push_back(llvm::formatv(formatString.data(), name));
+  }
+}
+
+/// Populates `builderLines` with additional lines that are required in the
+/// builder to set up op results.
+static void
+populateBuilderLinesResult(const Operator &op,
+                           llvm::ArrayRef<std::string> names,
+                           llvm::SmallVectorImpl<std::string> &builderLines) {
+  bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+
+  // For each element, find or generate a name.
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    const NamedTypeConstraint &element = op.getResult(i);
     std::string name = names[i];
 
     // Choose the formatting string based on the element kind.
     llvm::StringRef formatString;
     if (!element.isVariableLength()) {
-      formatString = singleElementAppendTemplate;
+      formatString = singleResultAppendTemplate;
     } else if (element.isOptional()) {
-      formatString = optionalAppendTemplate;
+      formatString = optionalAppendResultTemplate;
     } else {
       assert(element.isVariadic() && "unhandled element group type");
-      // If emitting with sizedSegments, then we add the actual list typed
-      // element using the singleElementAppendTemplate. Otherwise, we extend
-      // the actual operands.
+      // If emitting with sizedSegments, then we add the actual list-typed
+      // element. Otherwise, we extend the actual operands.
       if (sizedSegments) {
-        // Append the list as is.
-        formatString = singleElementAppendTemplate;
+        formatString = singleResultAppendTemplate;
       } else {
-        // Append the list elements.
-        formatString = multiElementAppendTemplate;
+        formatString = multiResultAppendTemplate;
       }
     }
 
-    // Add the lines.
-    builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
+    builderLines.push_back(llvm::formatv(formatString.data(), name));
   }
 }
 
@@ -680,12 +712,10 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
                       op.getNumNativeAttributes() + op.getNumSuccessors());
   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
 
-  populateBuilderLines(
-      op, "result",
-      llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
-      builderLines, getNumResults, getResult);
-  populateBuilderLines(op, "operand", operandArgNames, builderLines,
-                       getNumOperands, getOperand);
+  populateBuilderLinesResult(
+      op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
+      builderLines);
+  populateBuilderLinesOperand(op, operandArgNames, builderLines);
   populateBuilderLinesAttr(
       op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
       builderLines);


        


More information about the Mlir-commits mailing list