[Mlir-commits] [mlir] b164f23 - [mlir][python] support taking ops instead of values in op constructors
Alex Zinenko
llvmlistbot at llvm.org
Fri Oct 8 00:49:54 PDT 2021
Author: Alex Zinenko
Date: 2021-10-08T09:49:48+02:00
New Revision: b164f23c29fdf8b1e82fc4cfeab79c9fb6df918d
URL: https://github.com/llvm/llvm-project/commit/b164f23c29fdf8b1e82fc4cfeab79c9fb6df918d
DIFF: https://github.com/llvm/llvm-project/commit/b164f23c29fdf8b1e82fc4cfeab79c9fb6df918d.diff
LOG: [mlir][python] support taking ops instead of values in op constructors
Introduce support for accepting ops instead of values when constructing ops. A
single-result op can be used instead of a value, including in lists of values,
and any op can be used instead of a list of values. This is similar to, but
more powerful than, the C++ API that allows for implicitly casting an OpType to
Value if it is statically known to have a single result - the cast in Python is
based on the op dynamically having a single result, and also handles the
multi-result case. This allows building IR in a more concise way:
op = dialect.produce_multiple_results()
other = dialect.produce_single_result()
dialect.consume_multiple_results(other, op)
instead of having to access the results manually
op = dialect.produce_multiple_results()
other = dialect.produce_single_result()
dialect.consume_multiple_results(other.result, op.operation.results)
The dispatch is implemented directly in Python and is triggered automatically
for autogenerated OpView subclasses. Extension OpView classes should use the
functions provided in ods_common.py if they want to implement this behavior.
An alternative could be to implement the dispatch in the C++ bindings code, but
it would require forwarding opaque types through all Python functions down to a
binding call, which makes it hard to inspect them in Python, e.g., to obtain
the types of values.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D111306
Added:
Modified:
mlir/python/mlir/dialects/_linalg_ops_ext.py
mlir/python/mlir/dialects/_ods_common.py
mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/linalg/ops.py
mlir/test/python/dialects/scf.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index 5360967492d5c..b7641c0a4b53c 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -10,6 +10,7 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
+from ._ods_common import get_op_result_or_value as _get_op_result_or_value
def isa(cls: Type, ty: Type):
try:
@@ -26,11 +27,12 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
results = []
if isa(RankedTensorType, output.type):
results = [output.type]
- op = self.build_generic(results=results,
- operands=[value, output],
- attributes=None,
- loc=loc,
- ip=ip)
+ op = self.build_generic(
+ results=results,
+ operands=[_get_op_result_or_value(o) for o in [value, output]],
+ attributes=None,
+ loc=loc,
+ ip=ip)
OpView.__init__(self, op)
linalgDialect = Context.current.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, self.operation)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 2fbf3545f46d4..95c44186533f1 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -5,11 +5,14 @@
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
+from typing import Sequence as _Sequence, Union as _Union
__all__ = [
"equally_sized_accessor",
"extend_opview_class",
"get_default_loc_context",
+ "get_op_result_or_value",
+ "get_op_results_or_values",
"segmented_accessor",
]
@@ -118,3 +121,38 @@ def get_default_loc_context(location=None):
# Location.current raises ValueError if there is no current location.
return _cext.ir.Location.current.context
return location.context
+
+
+def get_op_result_or_value(
+ arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
+) -> _cext.ir.Value:
+ """Returns the given value or the single result of the given op.
+
+ This is useful to implement op constructors so that they can take other ops as
+ arguments instead of requiring the caller to extract results for every op.
+ Raises ValueError if provided with an op that doesn't have a single result.
+ """
+ if isinstance(arg, _cext.ir.OpView):
+ return arg.operation.result
+ elif isinstance(arg, _cext.ir.Operation):
+ return arg.result
+ else:
+ assert isinstance(arg, _cext.ir.Value)
+ return arg
+
+
+def get_op_results_or_values(
+ arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
+) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
+ """Returns the given sequence of values or the results of the given op.
+
+ This is useful to implement op constructors so that they can take other ops as
+ lists of arguments instead of requiring the caller to extract results for
+ every op.
+ """
+ if isinstance(arg, _cext.ir.OpView):
+ return arg.operation.results
+ elif isinstance(arg, _cext.ir.Operation):
+ return arg.results
+ else:
+ return arg
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index c6532a75632b5..a8924a7507a42 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -7,8 +7,8 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Any, Sequence
-
+from typing import Any, Optional, Sequence, Union
+from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
class ForOp:
"""Specialization for the SCF for op class."""
@@ -17,7 +17,8 @@ def __init__(self,
lower_bound,
upper_bound,
step,
- iter_args: Sequence[Any] = [],
+ iter_args: Optional[Union[Operation, OpView,
+ Sequence[Value]]] = None,
*,
loc=None,
ip=None):
@@ -26,14 +27,22 @@ def __init__(self,
- `lower_bound` is the value to use as lower bound of the loop.
- `upper_bound` is the value to use as upper bound of the loop.
- `step` is the value to use as loop step.
- - `iter_args` is a list of additional loop-carried arguments.
+ - `iter_args` is a list of additional loop-carried arguments or an operation
+ producing them as results.
"""
+ if iter_args is None:
+ iter_args = []
+ iter_args = _get_op_results_or_values(iter_args)
+
results = [arg.type for arg in iter_args]
super().__init__(
self.build_generic(
regions=1,
results=results,
- operands=[lower_bound, upper_bound, step] + list(iter_args),
+ operands=[
+ _get_op_result_or_value(o)
+ for o in [lower_bound, upper_bound, step]
+ ] + list(iter_args),
loc=loc,
ip=ip))
self.regions[0].blocks.append(IndexType.get(), *results)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 047bde245b645..1acae7a7a389a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Dict, List
+from typing import Dict, List, Sequence, Union
from contextlib import contextmanager
import functools
@@ -10,12 +10,15 @@
import threading
from ..... import ir
+from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from .comprehension import *
from .config import *
from .emitter import *
_CONTEXT = threading.local()
+StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
+ Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
@contextmanager
def bind_op_def(model: LinalgOpDef):
@@ -37,6 +40,15 @@ def current_op_def() -> LinalgOpDef:
"but none is set. Did you mean to call this in an op definition?")
+def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
+ if isinstance(outs, (ir.Operation, ir.OpView)):
+ return _get_op_results_or_values(outs)
+ elif isinstance(outs, ir.OpResultList):
+ return outs
+
+ return [_get_op_result_or_value(o) for o in outs]
+
+
class DefinedOpCallable:
"""Callable that wraps any defined op function."""
@@ -44,7 +56,8 @@ def __init__(self, op_name: str, model: LinalgOpDef):
self.op_name = op_name
self.model = model
- def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
+ def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
+ outs: StructuredOpOuts, **kwargs):
"""Emits the corresponding op definition as IR.
Most arguments are passed through to the underlying emitter. The following
@@ -73,17 +86,19 @@ def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
emit_generic or not ctx.is_registered_operation(fully_qualified_name))
op_config = op_configs[0]
+ out_values = _prepare_structured_op_outs(outs)
+ in_values = [_get_op_result_or_value(i) for i in ins]
if op_config.structured_op:
if emit_generic:
return emit_generic_structured_op(
- op_config.structured_op, *ins, outs=outs, **kwargs)
+ op_config.structured_op, *in_values, outs=out_values, **kwargs)
else:
return emit_named_structured_op(
op_config.structured_op,
self.op_name,
self.model.metadata.cpp_class_name,
- *ins,
- outs=outs,
+ *in_values,
+ outs=out_values,
**kwargs)
raise NotImplementedError(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 7feea040aa77c..021fe83285945 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Dict, Sequence
+from typing import Dict, List, Sequence, Tuple, Union
from .....ir import *
from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
@@ -10,6 +10,7 @@
from .... import linalg
from .... import std
from .... import math
+from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from .scalar_expr import *
from .config import *
@@ -18,8 +19,10 @@
__all__ = [
"emit_generic_structured_op",
"emit_named_structured_op",
+ "ValueList",
]
+ValueList = Union[Sequence[Value], OpResultList]
def isa(cls: Type, ty: Type):
try:
@@ -30,17 +33,18 @@ def isa(cls: Type, ty: Type):
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
- *ins: Value, outs: Sequence[Value],
+ *ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs = op_config.ordered_operands
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
- # Verify outs is a sequence.
- if not isinstance(outs, Sequence):
- raise ValueError(f"Expected named argument outs to have type Sequence "
- f"but got {type(outs)}")
+ # Verify outs is a sequence or a list of results.
+ if not isinstance(outs, (Sequence, OpResultList)):
+ raise ValueError(
+ f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
+ )
# Arity validation.
if len(ins) != len(in_arg_defs):
@@ -122,7 +126,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
- outs: Sequence[Value], **attrs: Sequence[int]):
+ outs: ValueList, **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -153,8 +157,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
- op_class_name: str, *ins: Value,
- outs: Sequence[Value], **attrs: Sequence[int]):
+ op_class_name: str, *ins: Value, outs: ValueList,
+ **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -355,11 +359,11 @@ def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
return std.MinUIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
-def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
- in_arg_defs: Sequence[OperandDefConfig],
- ins: Sequence[Value],
- out_arg_defs: Sequence[OperandDefConfig],
- outs: Sequence[Value]):
+def _infer_structured_outs(
+ op_config: LinalgStructuredOpConfig,
+ in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
+ out_arg_defs: Sequence[OperandDefConfig],
+ outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
"""Infers implicit outs and output types.
Respects existing contents of outs if not empty.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 572c657336686..c3ee0c47aa05a 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -24,9 +24,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: operands.append(variadic1)
- // CHECK: operands.append(non_variadic)
- // CHECK: if variadic2 is not None: operands.append(variadic2)
+ // CHECK: operands.append(_get_op_results_or_values(variadic1))
+ // CHECK: operands.append(_get_op_result_or_value(non_variadic))
+ // CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -150,8 +150,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: operands.append(_gen_arg_0)
- // CHECK: operands.append(_gen_arg_2)
+ // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
+ // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
@@ -197,9 +197,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: results.append(i32)
// CHECK: results.append(_gen_res_1)
// CHECK: results.append(i64)
- // CHECK: operands.append(_gen_arg_0)
- // CHECK: operands.append(f32)
- // CHECK: operands.append(_gen_arg_2)
+ // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
+ // CHECK: operands.append(_get_op_result_or_value(f32))
+ // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -230,8 +230,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: operands.append(non_variadic)
- // CHECK: operands.extend(variadic)
+ // CHECK: operands.append(_get_op_result_or_value(non_variadic))
+ // CHECK: operands.extend(_get_op_results_or_values(variadic))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -285,7 +285,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: operands.append(in_)
+ // CHECK: operands.append(_get_op_result_or_value(in_))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -353,8 +353,8 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: attributes = {}
// CHECK: results.append(i64)
// CHECK: results.append(f64)
- // CHECK: operands.append(i32)
- // CHECK: operands.append(f32)
+ // CHECK: operands.append(_get_op_result_or_value(i32))
+ // CHECK: operands.append(_get_op_result_or_value(f32))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 8b990e66e13ec..6f07969cce856 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -185,3 +185,30 @@ def generic_form(lhs, rhs):
return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
print(module)
+
+
+# CHECK-LABEL: TEST: testOpResultFromOtherOp
+@run
+def testOpResultFromOtherOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
+ f32))
+ def pass_an_op_directly(arg0, arg1):
+ one = std.ConstantOp(F32Type.get(), 1.0)
+ # CHECK: %[[LHS:.*]] = linalg.fill
+ lhs = linalg.FillOp(arg0, one)
+ # CHECK: %[[RHS:.*]] = linalg.fill
+ rhs = linalg.FillOp(arg1, one)
+ # CHECK: %[[INIT:.*]] = linalg.init_tensor
+ init = linalg.InitTensorOp([4, 8], f32)
+ # CHECK: linalg.matmul
+ # CHECK: ins(%[[LHS]], %[[RHS]]
+ # CHECK: outs(%[[INIT]]
+ return linalg.matmul(lhs, rhs, outs=init)
+
+ print(module)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index d604913b1c4cb..7819679d90985 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -2,53 +2,82 @@
from mlir.ir import *
from mlir.dialects import scf
+from mlir.dialects import std
from mlir.dialects import builtin
-def run(f):
+def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
- f()
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
return f
# CHECK-LABEL: TEST: testSimpleLoop
-@run
+@constructAndPrintInModule
def testSimpleLoop():
- with Context(), Location.unknown():
- module = Module.create()
- index_type = IndexType.get()
- with InsertionPoint(module.body):
+ index_type = IndexType.get()
- @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
- def simple_loop(lb, ub, step):
- loop = scf.ForOp(lb, ub, step, [lb, lb])
- with InsertionPoint(loop.body):
- scf.YieldOp(loop.inner_iter_args)
- return
+ @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+ def simple_loop(lb, ub, step):
+ loop = scf.ForOp(lb, ub, step, [lb, lb])
+ with InsertionPoint(loop.body):
+ scf.YieldOp(loop.inner_iter_args)
+ return
- # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
- # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
- # CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
- # CHECK: scf.yield %[[I1]], %[[I2]]
- print(module)
+
+# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
+# CHECK: scf.yield %[[I1]], %[[I2]]
# CHECK-LABEL: TEST: testInductionVar
-@run
+@constructAndPrintInModule
def testInductionVar():
- with Context(), Location.unknown():
- module = Module.create()
- index_type = IndexType.get()
- with InsertionPoint(module.body):
+ index_type = IndexType.get()
+
+ @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+ def induction_var(lb, ub, step):
+ loop = scf.ForOp(lb, ub, step, [lb])
+ with InsertionPoint(loop.body):
+ scf.YieldOp([loop.induction_variable])
+ return
+
+
+# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK: scf.yield %[[IV]]
+
+
+@constructAndPrintInModule
+def testOpsAsArguments():
+ index_type = IndexType.get()
+ callee = builtin.FuncOp(
+ "callee", ([], [index_type, index_type]), visibility="private")
+ func = builtin.FuncOp("ops_as_arguments", ([], []))
+ with InsertionPoint(func.add_entry_block()):
+ lb = std.ConstantOp.create_index(0)
+ ub = std.ConstantOp.create_index(42)
+ step = std.ConstantOp.create_index(2)
+ iter_args = std.CallOp(callee, [])
+ loop = scf.ForOp(lb, ub, step, iter_args)
+ with InsertionPoint(loop.body):
+ scf.YieldOp(loop.inner_iter_args)
+ std.ReturnOp([])
+
- @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
- def induction_var(lb, ub, step):
- loop = scf.ForOp(lb, ub, step, [lb])
- with InsertionPoint(loop.body):
- scf.YieldOp([loop.induction_variable])
- return
-
- # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
- # CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
- # CHECK: scf.yield %[[IV]]
- print(module)
+# CHECK-LABEL: TEST: testOpsAsArguments
+# CHECK: func private @callee() -> (index, index)
+# CHECK: func @ops_as_arguments() {
+# CHECK: %[[LB:.*]] = constant 0
+# CHECK: %[[UB:.*]] = constant 42
+# CHECK: %[[STEP:.*]] = constant 2
+# CHECK: %[[ARGS:.*]]:2 = call @callee()
+# CHECK: scf.for %arg0 = %c0 to %c42 step %c2
+# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
+# CHECK: scf.yield %{{.*}}, %{{.*}}
+# CHECK: return
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 742cad748ca46..51993bbf6b051 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -28,7 +28,7 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir
try:
@@ -489,20 +489,25 @@ constexpr const char *initTemplate = R"Py(
)Py";
/// Template for appending a single element to the operand/result list.
-/// {0} is either 'operand' or 'result';
-/// {1} is the field name.
-constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
+/// {0} is the field name.
+constexpr const char *singleOperandAppendTemplate =
+ "operands.append(_get_op_result_or_value({0}))";
+constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Template for appending an optional element to the operand/result list.
-/// {0} is either 'operand' or 'result';
-/// {1} is the field name.
-constexpr const char *optionalAppendTemplate =
- "if {1} is not None: {0}s.append({1})";
-
-/// Template for appending a a list of elements to the operand/result list.
-/// {0} is either 'operand' or 'result';
-/// {1} is the field name.
-constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
+/// {0} is the field name.
+constexpr const char *optionalAppendOperandTemplate =
+ "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
+constexpr const char *optionalAppendResultTemplate =
+ "if {0} is not None: results.append({0})";
+
+/// Template for appending a list of elements to the operand/result list.
+/// {0} is the field name.
+constexpr const char *multiOperandAppendTemplate =
+ "operands.extend(_get_op_results_or_values({0}))";
+constexpr const char *multiOperandAppendPackTemplate =
+ "operands.append(_get_op_results_or_values({0}))";
+constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
@@ -625,43 +630,70 @@ static void populateBuilderLinesSuccessors(
}
/// Populates `builderLines` with additional lines that are required in the
-/// builder. `kind` must be either "operand" or "result". `names` contains the
-/// names of init arguments that correspond to the elements.
-static void populateBuilderLines(
- const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
- llvm::SmallVectorImpl<std::string> &builderLines,
- llvm::function_ref<int(const Operator &)> getNumElements,
- llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
- getElement) {
- bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
+/// builder to set up op operands.
+static void
+populateBuilderLinesOperand(const Operator &op,
+ llvm::ArrayRef<std::string> names,
+ llvm::SmallVectorImpl<std::string> &builderLines) {
+ bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
// For each element, find or generate a name.
- for (int i = 0, e = getNumElements(op); i < e; ++i) {
- const NamedTypeConstraint &element = getElement(op, i);
+ for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+ const NamedTypeConstraint &element = op.getOperand(i);
+ std::string name = names[i];
+
+ // Choose the formatting string based on the element kind.
+ llvm::StringRef formatString;
+ if (!element.isVariableLength()) {
+ formatString = singleOperandAppendTemplate;
+ } else if (element.isOptional()) {
+ formatString = optionalAppendOperandTemplate;
+ } else {
+ assert(element.isVariadic() && "unhandled element group type");
+ // If emitting with sizedSegments, then we add the actual list-typed
+ // element. Otherwise, we extend the actual operands.
+ if (sizedSegments) {
+ formatString = multiOperandAppendPackTemplate;
+ } else {
+ formatString = multiOperandAppendTemplate;
+ }
+ }
+
+ builderLines.push_back(llvm::formatv(formatString.data(), name));
+ }
+}
+
+/// Populates `builderLines` with additional lines that are required in the
+/// builder to set up op results.
+static void
+populateBuilderLinesResult(const Operator &op,
+ llvm::ArrayRef<std::string> names,
+ llvm::SmallVectorImpl<std::string> &builderLines) {
+ bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+
+ // For each element, find or generate a name.
+ for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+ const NamedTypeConstraint &element = op.getResult(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
if (!element.isVariableLength()) {
- formatString = singleElementAppendTemplate;
+ formatString = singleResultAppendTemplate;
} else if (element.isOptional()) {
- formatString = optionalAppendTemplate;
+ formatString = optionalAppendResultTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
- // If emitting with sizedSegments, then we add the actual list typed
- // element using the singleElementAppendTemplate. Otherwise, we extend
- // the actual operands.
+ // If emitting with sizedSegments, then we add the actual list-typed
+ // element. Otherwise, we extend the actual operands.
if (sizedSegments) {
- // Append the list as is.
- formatString = singleElementAppendTemplate;
+ formatString = singleResultAppendTemplate;
} else {
- // Append the list elements.
- formatString = multiElementAppendTemplate;
+ formatString = multiResultAppendTemplate;
}
}
- // Add the lines.
- builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
+ builderLines.push_back(llvm::formatv(formatString.data(), name));
}
}
@@ -680,12 +712,10 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
op.getNumNativeAttributes() + op.getNumSuccessors());
populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
- populateBuilderLines(
- op, "result",
- llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
- builderLines, getNumResults, getResult);
- populateBuilderLines(op, "operand", operandArgNames, builderLines,
- getNumOperands, getOperand);
+ populateBuilderLinesResult(
+ op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
+ builderLines);
+ populateBuilderLinesOperand(op, operandArgNames, builderLines);
populateBuilderLinesAttr(
op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
builderLines);
More information about the Mlir-commits
mailing list