[Mlir-commits] [mlir] 27c6d55 - [mlir][python] generate value builders (#68308)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 9 14:16:33 PDT 2023
Author: Maksim Levental
Date: 2023-10-09T14:16:28-07:00
New Revision: 27c6d55cae74125b6381a647533090a72930ecda
URL: https://github.com/llvm/llvm-project/commit/27c6d55cae74125b6381a647533090a72930ecda
DIFF: https://github.com/llvm/llvm-project/commit/27c6d55cae74125b6381a647533090a72930ecda.diff
LOG: [mlir][python] generate value builders (#68308)
This PR adds the additional generation of what I'm calling "value
builders" (a term I'm not married to) that look like this:
```python
def empty(sizes, element_type, *, loc=None, ip=None):
    return get_result_or_results(tensor.EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip))
```
which instantiates a `tensor.EmptyOp` and then immediately grabs the
result (`OpResult`) and then returns that *instead of a handle to the
op*.
What's the point of adding these when `EmptyOp.result` already exists?
My claim/feeling/intuition is that eDSL users are more comfortable with
a value centric programming model (i.e., passing values as operands) as
opposed to an operator instantiation programming model. Thus this change
enables (or at least goes towards) the bindings supporting such a user
and use case. For example,
```python
i32 = IntegerType.get_signless(32)
...
ten1 = tensor.empty((10, 10), i32)
ten2 = tensor.empty((10, 10), i32)
ten3 = arith.addi(ten1, ten2)
```
Note, in order to present a "pythonic" API and enable "pythonic" eDSLs,
the generated identifiers (op names and operand names) are snake case
instead of camel case and thus `llvm::convertToSnakeFromCamelCase`
needed a small fix. Thus this PR is stacked on top of
https://github.com/llvm/llvm-project/pull/68375.
In addition, as a kind of victory lap, this PR adds a "rangefor" that
looks and acts exactly like python's `range` but emits `scf.for`.
Added:
Modified:
mlir/python/mlir/dialects/_ods_common.py
mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/python/mlir/dialects/scf.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/scf.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 7655629a5542520..895c3228139b392 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -13,6 +13,7 @@
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
+ "get_op_result_or_op_results",
"segmented_accessor",
]
@@ -167,3 +168,17 @@ def get_op_results_or_values(
return arg.results
else:
return [get_op_result_or_value(element) for element in arg]
+
+
+def get_op_result_or_op_results(
+ op: _Union[_cext.ir.OpView, _cext.ir.Operation],
+) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
+ if isinstance(op, _cext.ir.OpView):
+ op = op.operation
+ return (
+ list(get_op_results_or_values(op))
+ if len(op.results) > 1
+ else get_op_result_or_value(op)
+ if len(op.results) > 0
+ else op
+ )
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index 4b0a31327abb0ee..89cc8a19895c7b4 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -7,7 +7,8 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Any, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
+
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
@@ -25,7 +26,7 @@ def __init__(
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
- ip=None
+ ip=None,
):
"""Creates an SCF `for` operation.
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 302a49d56c211a1..49685ca2271fc61 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -2,4 +2,42 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional, Sequence
+
from ._scf_ops_gen import *
+from .arith import constant
+from ..ir import *
+
+
+def for_(
+ start,
+ stop=None,
+ step=None,
+ iter_args: Optional[Sequence[Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+):
+ if step is None:
+ step = 1
+ if stop is None:
+ stop = start
+ start = 0
+ params = [start, stop, step]
+ for i, p in enumerate(params):
+ if isinstance(p, int):
+ p = constant(p)
+ elif isinstance(p, float):
+ raise ValueError(f"{p=} must be int.")
+ params[i] = p
+
+ for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+ iv = for_op.induction_variable
+ iter_args = tuple(for_op.inner_iter_args)
+ with InsertionPoint(for_op.body):
+ if len(iter_args) > 1:
+ yield iv, iter_args
+ elif len(iter_args) == 1:
+ yield iv, iter_args[0]
+ else:
+ yield iv
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index a131209fa45cb6c..8ca23fa9f45c4ab 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,6 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
Optional<AnyType>:$variadic2);
}
+// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
@@ -104,6 +107,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
Variadic<AnyType>:$variadic2);
}
+// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOp(_ods_ir.OpView):
@@ -151,6 +157,9 @@ def AttributedOp : TestOp<"attributed_op"> {
UnitAttr:$unitAttr, I32Attr:$in);
}
+// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
@@ -184,6 +193,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}
+// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
@@ -205,6 +217,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
let results = (outs);
}
+// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
@@ -220,6 +235,9 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
let results = (outs AnyType:$res, AnyType);
}
+// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
@@ -227,6 +245,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
let results = (outs AnyType:$res, Variadic<AnyType>);
}
+// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.empty"
@@ -241,6 +262,8 @@ def EmptyOp : TestOp<"empty">;
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+// CHECK: def empty(*, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -252,6 +275,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
let results = (outs I32:$i32, F32:$f32);
}
+// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
// CHECK: def __init__(self, *, loc=None, ip=None):
@@ -262,6 +288,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
let results = (outs AnyType, AnyType, AnyType);
}
+// CHECK: def infer_result_types_op(*, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
@@ -297,6 +326,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
+// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
@@ -323,9 +355,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: @builtins.property
// CHECK: def optional(self):
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
-
}
+// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
@@ -355,6 +389,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
}
+// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
@@ -385,6 +422,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
}
+// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
@@ -405,6 +445,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: return self.operation.operands[0]
let arguments = (ins AnyType:$in);
}
+
+// CHECK: def python_keyword(in_, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
// CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
@@ -416,6 +460,9 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
let results = (outs AnyType:$res);
}
+// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
// CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
@@ -423,6 +470,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
let results = (outs Variadic<AnyType>:$res);
}
+// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
@@ -447,6 +497,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
Variadic<AnyType>:$variadic2);
}
+// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
@@ -470,6 +523,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
Variadic<AnyType>:$variadic2);
}
+// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.simple"
@@ -507,6 +563,9 @@ def SimpleOp : TestOp<"simple"> {
let results = (outs I64:$i64, AnyFloat:$f64);
}
+// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
@@ -531,6 +590,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: return self.regions[2:]
}
+// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
def VariadicRegionOp : TestOp<"variadic_region"> {
@@ -551,6 +613,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: return self.regions[0:]
}
+// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
@@ -562,3 +627,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
let successors = (successor AnySuccessor:$successor,
VariadicSuccessor<AnySuccessor>:$successors);
}
+
+// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
+// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 8cb55fdf6a1eb3b..414307d8191513b 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -4,7 +4,6 @@
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import scf
-from mlir.dialects import builtin
def constructAndPrintInModule(f):
@@ -54,6 +53,28 @@ def induction_var(lb, ub, step):
# CHECK: scf.yield %[[IV]]
+# CHECK-LABEL: TEST: testForSugar
+@constructAndPrintInModule
+def testForSugar():
+ index_type = IndexType.get()
+ range = scf.for_
+
+ @func.FuncOp.from_py_func(index_type, index_type, index_type)
+ def range_loop(lb, ub, step):
+ for i in range(lb, ub, step):
+ add = arith.addi(i, i)
+ scf.yield_([])
+ return
+
+
+# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
+# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK: }
+# CHECK: return
+# CHECK: }
+
+
@constructAndPrintInModule
def testOpsAsArguments():
index_type = IndexType.get()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0b5df7ab70dddb2..fc094a1829ff755 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir
try:
@@ -39,6 +39,7 @@ except ImportError:
_ods_ext_module = None
import builtins
+from typing import Sequence as _Sequence, Union as _Union
)Py";
@@ -260,11 +261,16 @@ constexpr const char *attributeDeleterTemplate = R"Py(
del self.operation.attributes["{1}"]
)Py";
-constexpr const char *regionAccessorTemplate = R"PY(
+constexpr const char *regionAccessorTemplate = R"Py(
@builtins.property
def {0}(self):
return self.regions[{1}]
-)PY";
+)Py";
+
+constexpr const char *valueBuilderTemplate = R"Py(
+def {0}({2}) -> {4}:
+ return _get_op_result_or_op_results({1}({3}))
+)Py";
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
@@ -609,9 +615,7 @@ populateBuilderArgsResults(const Operator &op,
static void
populateBuilderArgs(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
- llvm::SmallVectorImpl<std::string> &operandNames,
- llvm::SmallVectorImpl<std::string> &successorArgNames) {
-
+ llvm::SmallVectorImpl<std::string> &operandNames) {
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str();
if (name.empty())
@@ -734,11 +738,11 @@ populateBuilderLinesOperand(const Operator &op,
/// attribute:
/// - {0} is the name of the attribute from which to derive the types.
constexpr const char *deriveTypeFromAttrTemplate =
- R"PY(_ods_result_type_source_attr = attributes["{0}"]
+ R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
- _ods_result_type_source_attr.type))PY";
+ _ods_result_type_source_attr.type))Py";
/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
@@ -836,11 +840,14 @@ populateBuilderRegions(const Operator &op,
}
/// Emits a default builder constructing an operation from the list of its
-/// result types, followed by a list of its operands.
-static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
+/// result types, followed by a list of its operands. Returns vector
+/// of fully built functionArgs for downstream users (to save having to
+/// rebuild anew).
+static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
+ raw_ostream &os) {
// If we are asked to skip default builders, comply.
if (op.skipDefaultBuilders())
- return;
+ return {};
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
@@ -850,7 +857,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
op.getNumNativeAttributes() + op.getNumSuccessors());
populateBuilderArgsResults(op, builderArgs);
size_t numResultArgs = builderArgs.size();
- populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
+ populateBuilderArgs(op, builderArgs, operandArgNames);
size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
@@ -921,6 +928,8 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "),
llvm::join(initArgs, ", "));
+ return llvm::to_vector<8>(
+ llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
}
static void emitSegmentSpec(
@@ -968,6 +977,45 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
}
}
+/// Emits builder that extracts results from op
+static void emitValueBuilder(const Operator &op,
+ llvm::SmallVector<std::string> functionArgs,
+ raw_ostream &os) {
+ // If we are asked to skip default builders, comply.
+ if (op.skipDefaultBuilders())
+ return;
+ auto name = sanitizeName(op.getOperationName());
+ iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
+ // Params with (possibly) default args.
+ auto valueBuilderParams =
+ llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
+ llvm::SmallVector<llvm::StringRef> argMaybeDefault =
+ llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
+ auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
+ if (argMaybeDefault.size() == 2)
+ return arg + "=" + argMaybeDefault[1].str();
+ return arg;
+ });
+ // Actual args passed to op builder (e.g., opParam=op_param).
+ auto opBuilderArgs = llvm::map_range(
+ llvm::make_filter_range(functionArgs,
+ [](const std::string &s) { return s != "*"; }),
+ [](const std::string &arg) {
+ auto lhs = *llvm::split(arg, "=").begin();
+ return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
+ });
+ os << llvm::formatv(
+ valueBuilderTemplate,
+ // Drop dialect name and then sanitize again (to catch e.g. func.return).
+ sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
+ op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
+ llvm::join(opBuilderArgs, ", "),
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.OpResult]"
+ : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+ : "_ods_ir.Operation")));
+}
+
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
@@ -982,11 +1030,12 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
}
emitRegionAttributes(op, os);
- emitDefaultOpBuilder(op, os);
+ llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
emitAttributeAccessors(op, os);
emitResultAccessors(op, os);
emitRegionAccessors(op, os);
+ emitValueBuilder(op, functionArgs, os);
}
/// Emits bindings for the dialect specified in the command line, including file
More information about the Mlir-commits
mailing list