[llvm-branch-commits] [mlir] 71b6b01 - [mlir][python] Factor out standalone OpView._ods_build_default class method.
Stella Laurenzo via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 19 09:38:52 PST 2021
Author: Stella Laurenzo
Date: 2021-01-19T09:29:57-08:00
New Revision: 71b6b010e6bc49caaec511195e33ac1f43f07c64
URL: https://github.com/llvm/llvm-project/commit/71b6b010e6bc49caaec511195e33ac1f43f07c64
DIFF: https://github.com/llvm/llvm-project/commit/71b6b010e6bc49caaec511195e33ac1f43f07c64.diff
LOG: [mlir][python] Factor out standalone OpView._ods_build_default class method.
* This allows us to hoist trait-level information for regions and sized variadics into class-level attributes (_ODS_REGIONS, _ODS_OPERAND_SEGMENTS, _ODS_RESULT_SEGMENTS).
* Replaces some spliced-together generated Python code with a native helper.
* Makes it possible to implement custom, variadic and region-based builders with one line of Python, without needing to manually code access to the segment attributes.
* Needs follow-on work for region-based callbacks and support for SingleBlockImplicitTerminator.
* A follow-up will actually add ODS support for generating custom Python builders that delegate to this new method.
* Also includes the start of an e2e sample for constructing linalg ops where this limitation was discovered (working progressively through this example and cleaning up as I go).
Differential Revision: https://reviews.llvm.org/D94738
Added:
mlir/examples/python/linalg_matmul.py
mlir/test/Bindings/Python/ods_helpers.py
Modified:
mlir/docs/Bindings/Python.md
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index b5595bc7010e..6bb9e7ebe2f6 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -365,7 +365,7 @@ for the canonical way to use this facility.
Each dialect with a mapping to python requires that an appropriate
`{DIALECT_NAMESPACE}.py` wrapper module is created. This is done by invoking
-`mlir-tablegen` on a python-bindings specific tablegen wrapper that includes
+`mlir-tblgen` on a python-bindings specific tablegen wrapper that includes
the boilerplate and actual dialect specific `td` file. An example, for the
`StandardOps` (which is assigned the namespace `std` as a special case):
@@ -383,7 +383,7 @@ In the main repository, building the wrapper is done via the CMake function
`add_mlir_dialect_python_bindings`, which invokes:
```
-mlir-tablegen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \
+mlir-tblgen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \
{PYTHON_BINDING_TD_FILE}
```
@@ -411,7 +411,8 @@ The wrapper module tablegen emitter outputs:
Note: In order to avoid naming conflicts, all internal names used by the wrapper
module are prefixed by `_ods_`.
-Each concrete `OpView` subclass further defines several attributes:
+Each concrete `OpView` subclass further defines several public-intended
+attributes:
* `OPERATION_NAME` attribute with the `str` fully qualified operation name
(i.e. `std.absf`).
@@ -421,6 +422,20 @@ Each concrete `OpView` subclass further defines several attributes:
for unnamed of each).
* `@property` getter, setter and deleter for each declared attribute.
+It further emits additional private-intended attributes meant for subclassing
+and customization (default cases omit these attributes in favor of the
+defaults on `OpView`):
+
+* `_ODS_REGIONS`: A specification of the number of regions and whether they
+ are variadic. Currently a tuple of (min_region_count,
+ has_no_variadic_regions). Note that the API does some light validation on
+ this, but its primary purpose is to capture sufficient information to
+ perform other default building and region accessor generation.
+* `_ODS_OPERAND_SEGMENTS` and `_ODS_RESULT_SEGMENTS`: Black-box values that
+ indicate the structure of the operands or results with respect to
+ variadics. Used by `OpView._ods_build_default` to decode operand and result
+ lists that contain lists.
+
#### Builders
Presently, only a single, default builder is mapped to the `__init__` method.
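To make the segment attributes described above more concrete, here is a minimal pure-Python sketch of how a segment spec could drive the flattening of a nested operand or result list. It assumes the encoding used by the C++ `odsBuildDefault` in this commit (1 = required single value, 0 = optional single value that may be `None`, -1 = variadic sequence, and a spec of `None` meaning no variadic unpacking). `decode_segments` is a hypothetical helper for illustration only, not part of the `mlir` package, and plain strings stand in for `Value`/`Type` objects.

```python
from typing import Any, List, Optional, Sequence, Tuple


def decode_segments(
    items: Sequence[Any], spec: Optional[Sequence[int]], kind: str
) -> Tuple[List[Any], List[int]]:
    """Hypothetical helper: flattens `items` per a segment spec.

    Returns (flat_items, segment_lengths). Spec entries are assumed to be:
      1  -> exactly one required item
      0  -> one optional item (None means "absent")
      -1 -> a variadic segment given as a sequence (None means empty)
    """
    if spec is None:
        # No segment metadata: use the list as-is, but reject None entries.
        if any(item is None for item in items):
            raise ValueError(f"{kind} list must not contain None")
        return list(items), []
    if len(spec) != len(items):
        raise ValueError(
            f"requires {len(spec)} {kind} segments but was provided {len(items)}")
    flat: List[Any] = []
    lengths: List[int] = []
    for index, (item, segment_spec) in enumerate(zip(items, spec)):
        if segment_spec in (0, 1):
            if item is not None:
                flat.append(item)
                lengths.append(1)
            elif segment_spec == 0:
                lengths.append(0)  # optional and omitted
            else:
                raise ValueError(f"{kind} {index} was None and is not optional")
        elif segment_spec == -1:
            segment = [] if item is None else list(item)
            if any(x is None for x in segment):
                raise ValueError(f"{kind} {index} contained a None item")
            flat.extend(segment)
            lengths.append(len(segment))
        else:
            raise ValueError("Unexpected segment spec")
    return flat, lengths
```

With the spec `[1, -1, 0]` and operands `["v0", ["v1", "v2"], "v3"]`, this sketch yields the flat list `["v0", "v1", "v2", "v3"]` and lengths `[1, 2, 1]`, the same shape as the `operand_segment_sizes` that the `ods_helpers.py` test in this commit checks for.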
diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py
new file mode 100644
index 000000000000..83dc15eda9b6
--- /dev/null
+++ b/mlir/examples/python/linalg_matmul.py
@@ -0,0 +1,73 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This is a work-in-progress example of end-to-end building and code generation
+# for a small linalg program with configuration options. It is currently
+# non-functional and is being used to elaborate the APIs.
+
+from typing import Tuple
+
+from mlir.ir import *
+from mlir.dialects import linalg
+from mlir.dialects import std
+
+
+# TODO: This should be in the core API.
+def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
+ """Creates a |func| op.
+ TODO: This should really be in the MLIR API.
+ Returns:
+ (operation, entry_block)
+ """
+ attrs = {
+ "type": TypeAttr.get(func_type),
+ "sym_name": StringAttr.get(name),
+ }
+ op = Operation.create("func", regions=1, attributes=attrs)
+ body_region = op.regions[0]
+ entry_block = body_region.blocks.append(*func_type.inputs)
+ return op, entry_block
+
+
+# TODO: Generate a custom builder rather than patching one in.
+def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None):
+ super(linalg.MatmulOp, self).__init__(
+ self._ods_build_default(operands=[[lhs, rhs], [result]],
+ results=[],
+ loc=loc,
+ ip=ip))
+ # TODO: Implement support for SingleBlockImplicitTerminator
+ block = self.regions[0].blocks.append()
+ with InsertionPoint(block):
+ linalg.YieldOp(values=[])
+
+linalg.MatmulOp.__init__ = PatchMatmulOpInit
+
+
+def build_matmul_func(func_name, m, k, n, dtype):
+ lhs_type = MemRefType.get(dtype, [m, k])
+ rhs_type = MemRefType.get(dtype, [k, n])
+ result_type = MemRefType.get(dtype, [m, n])
+ # TODO: There should be a one-liner for this.
+ func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
+ _, entry = FuncOp(func_name, func_type)
+ lhs, rhs, result = entry.arguments
+ with InsertionPoint(entry):
+ linalg.MatmulOp(lhs, rhs, result)
+ std.ReturnOp([])
+
+
+def run():
+ with Context() as c, Location.unknown():
+ module = Module.create()
+ # TODO: This at_block_terminator vs default construct distinction feels
+ # wrong and is error-prone.
+ with InsertionPoint.at_block_terminator(module.body):
+ build_matmul_func('main', 18, 32, 96, F32Type.get())
+
+ print(module)
+ print(module.operation.get_asm(print_generic_op_form=True))
+
+
+if __name__ == '__main__': run()
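The example above patches `linalg.MatmulOp.__init__` with a function defined at module scope, which is why it calls the two-argument `super(linalg.MatmulOp, self)`: zero-argument `super()` depends on the `__class__` cell that Python only creates for functions defined inside a class body. The pure-Python sketch below reproduces that pattern with hypothetical stand-ins (`FakeOpView`, `FakeMatmulOp`, `patched_init`) rather than the real bindings, and its `_ods_build_default` merely records the class-level metadata instead of creating an operation. The segment spec on `FakeMatmulOp` is likewise an assumption.

```python
class FakeOpView:
    """Hypothetical stand-in for OpView: the constructor stores its argument."""
    _ODS_OPERAND_SEGMENTS = None

    def __init__(self, operation):
        self.operation = operation

    @classmethod
    def _ods_build_default(cls, operands=None, results=None):
        # Stand-in for the native helper: `cls` is the concrete subclass, so
        # class-level metadata overridden there is what gets consulted.
        return {"segments": cls._ODS_OPERAND_SEGMENTS,
                "operands": operands, "results": results}


class FakeMatmulOp(FakeOpView):
    # Hypothetical spec: one variadic input segment, one variadic output segment.
    _ODS_OPERAND_SEGMENTS = [-1, -1]

    def __init__(self, *args, **kwargs):
        raise NotImplementedError("generated builder; replaced below")


def patched_init(self, lhs, rhs, result):
    # Defined outside the class body, so there is no __class__ cell and a bare
    # super() would raise RuntimeError; name the class explicitly instead.
    super(FakeMatmulOp, self).__init__(
        self._ods_build_default(operands=[[lhs, rhs], [result]], results=[]))


FakeMatmulOp.__init__ = patched_init
```

Constructing `FakeMatmulOp("A", "B", "C")` then routes through the patched builder, and the classmethod sees `FakeMatmulOp`'s own `_ODS_OPERAND_SEGMENTS`, which is the property that lets a custom builder delegate to `_ods_build_default` in one line.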
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 493ea5c1e47a..63bdd0c7a184 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -130,6 +130,13 @@ equivalent to printing the operation that produced it.
// Utilities.
//------------------------------------------------------------------------------
+// Helper for creating an @classmethod.
+template <class Func, typename... Args>
+py::object classmethod(Func f, Args... args) {
+ py::object cf = py::cpp_function(f, args...);
+ // PyClassMethod_New returns a new reference, so steal it rather than borrow.
+ return py::reinterpret_steal<py::object>(PyClassMethod_New(cf.ptr()));
+}
+
/// Checks whether the given type is an integer or float type.
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
@@ -1027,6 +1034,267 @@ py::object PyOperation::createOpView() {
return py::cast(PyOpView(getRef().getObject()));
}
+//------------------------------------------------------------------------------
+// PyOpView
+//------------------------------------------------------------------------------
+
+py::object
+PyOpView::odsBuildDefault(py::object cls, py::list operandList,
+ py::list resultTypeList,
+ llvm::Optional<py::dict> attributes,
+ llvm::Optional<std::vector<PyBlock *>> successors,
+ llvm::Optional<int> regions,
+ DefaultingPyLocation location, py::object maybeIp) {
+ PyMlirContextRef context = location->getContext();
+ // Class level operation construction metadata.
+ std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
+ // Operand and result segment specs are either None, which does no
+ // variadic unpacking, or a list of ints with one entry per segment: 1 for
+ // a required single value, 0 for an optional single value (which may be
+ // None), or -1 for a variadic segment whose length is derived from the
+ // same-indexed operand or result (implying that it is a list there).
+ py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
+ py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
+
+ std::vector<uint64_t> operandSegmentLengths;
+ std::vector<uint64_t> resultSegmentLengths;
+
+ // Validate/determine region count.
+ auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+ int opMinRegionCount = std::get<0>(opRegionSpec);
+ bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
+ if (!regions) {
+ regions = opMinRegionCount;
+ }
+ if (*regions < opMinRegionCount) {
+ throw py::value_error(
+ (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
+ llvm::Twine(opMinRegionCount) +
+ " regions but was built with regions=" + llvm::Twine(*regions))
+ .str());
+ }
+ if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
+ throw py::value_error(
+ (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
+ llvm::Twine(opMinRegionCount) +
+ " regions but was built with regions=" + llvm::Twine(*regions))
+ .str());
+ }
+
+ // Unpack results.
+ std::vector<PyType *> resultTypes;
+ resultTypes.reserve(resultTypeList.size());
+ if (resultSegmentSpecObj.is_none()) {
+ // Non-variadic result unpacking.
+ for (auto it : llvm::enumerate(resultTypeList)) {
+ try {
+ resultTypes.push_back(py::cast<PyType *>(it.value()));
+ if (!resultTypes.back())
+ throw py::cast_error();
+ } catch (py::cast_error &err) {
+ throw py::value_error((llvm::Twine("Result ") +
+ llvm::Twine(it.index()) + " of operation \"" +
+ name + "\" must be a Type (" + err.what() + ")")
+ .str());
+ }
+ }
+ } else {
+ // Sized result unpacking.
+ auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
+ if (resultSegmentSpec.size() != resultTypeList.size()) {
+ throw py::value_error((llvm::Twine("Operation \"") + name +
+ "\" requires " +
+ llvm::Twine(resultSegmentSpec.size()) +
+ " result segments but was provided " +
+ llvm::Twine(resultTypeList.size()))
+ .str());
+ }
+ resultSegmentLengths.reserve(resultTypeList.size());
+ for (auto it :
+ llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
+ int segmentSpec = std::get<1>(it.value());
+ if (segmentSpec == 1 || segmentSpec == 0) {
+ // Unpack unary element.
+ try {
+ auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
+ if (resultType) {
+ resultTypes.push_back(resultType);
+ resultSegmentLengths.push_back(1);
+ } else if (segmentSpec == 0) {
+ // Allowed to be optional.
+ resultSegmentLengths.push_back(0);
+ } else {
+ throw py::cast_error("was None and result is not optional");
+ }
+ } catch (py::cast_error &err) {
+ throw py::value_error((llvm::Twine("Result ") +
+ llvm::Twine(it.index()) + " of operation \"" +
+ name + "\" must be a Type (" + err.what() +
+ ")")
+ .str());
+ }
+ } else if (segmentSpec == -1) {
+ // Unpack sequence by appending.
+ try {
+ if (std::get<0>(it.value()).is_none()) {
+ // Treat it as an empty list.
+ resultSegmentLengths.push_back(0);
+ } else {
+ // Unpack the list.
+ auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
+ for (py::object segmentItem : segment) {
+ resultTypes.push_back(py::cast<PyType *>(segmentItem));
+ if (!resultTypes.back()) {
+ throw py::cast_error("contained a None item");
+ }
+ }
+ resultSegmentLengths.push_back(segment.size());
+ }
+ } catch (std::exception &err) {
+ // NOTE: Sloppy to be using a catch-all here, but there are at least
+ // three different unrelated exceptions that can be thrown in the
+ // above "casts". Just keep the scope above small and catch them all.
+ throw py::value_error((llvm::Twine("Result ") +
+ llvm::Twine(it.index()) + " of operation \"" +
+ name + "\" must be a Sequence of Types (" +
+ err.what() + ")")
+ .str());
+ }
+ } else {
+ throw py::value_error("Unexpected segment spec");
+ }
+ }
+ }
+
+ // Unpack operands.
+ std::vector<PyValue *> operands;
+ operands.reserve(operandList.size());
+ if (operandSegmentSpecObj.is_none()) {
+ // Non-sized operand unpacking.
+ for (auto it : llvm::enumerate(operandList)) {
+ try {
+ operands.push_back(py::cast<PyValue *>(it.value()));
+ if (!operands.back())
+ throw py::cast_error();
+ } catch (py::cast_error &err) {
+ throw py::value_error((llvm::Twine("Operand ") +
+ llvm::Twine(it.index()) + " of operation \"" +
+ name + "\" must be a Value (" + err.what() + ")")
+ .str());
+ }
+ }
+ } else {
+ // Sized operand unpacking.
+ auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
+ if (operandSegmentSpec.size() != operandList.size()) {
+ throw py::value_error((llvm::Twine("Operation \"") + name +
+ "\" requires " +
+ llvm::Twine(operandSegmentSpec.size()) +
+ " operand segments but was provided " +
+ llvm::Twine(operandList.size()))
+ .str());
+ }
+ operandSegmentLengths.reserve(operandList.size());
+ for (auto it :
+ llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
+ int segmentSpec = std::get<1>(it.value());
+ if (segmentSpec == 1 || segmentSpec == 0) {
+ // Unpack unary element.
+ try {
+ auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
+ if (operandValue) {
+ operands.push_back(operandValue);
+ operandSegmentLengths.push_back(1);
+ } else if (segmentSpec == 0) {
+ // Allowed to be optional.
+ operandSegmentLengths.push_back(0);
+ } else {
+ throw py::cast_error("was None and operand is not optional");
+ }
+ } catch (py::cast_error &err) {
+ throw py::value_error((llvm::Twine("Operand ") +
+ llvm::Twine(it.index()) + " of operation \"" +
+ name + "\" must be a Value (" + err.what() +
+ ")")
+ .str());
+ }
+ } else if (segmentSpec == -1) {
+ // Unpack sequence by appending.
+ try {
+ if (std::get<0>(it.value()).is_none()) {
+ // Treat it as an empty list.
+ operandSegmentLengths.push_back(0);
+ } else {
+ // Unpack the list.
+ auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
+ for (py::object segmentItem : segment) {
+ operands.push_back(py::cast<PyValue *>(segmentItem));
+ if (!operands.back()) {
+ throw py::cast_error("contained a None item");
+ }
+ }
+ operandSegmentLengths.push_back(segment.size());
+ }
+ } catch (std::exception &err) {
+ // NOTE: Sloppy to be using a catch-all here, but there are at least
+ // three different unrelated exceptions that can be thrown in the
+ // above "casts". Just keep the scope above small and catch them all.
+ throw py::value_error((llvm::Twine("Operand ") +
+ llvm::Twine(it.index()) + " of operation \"" +
+ name + "\" must be a Sequence of Values (" +
+ err.what() + ")")
+ .str());
+ }
+ } else {
+ throw py::value_error("Unexpected segment spec");
+ }
+ }
+ }
+
+ // Merge operand/result segment lengths into attributes if needed.
+ if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
+ // Dup.
+ if (attributes) {
+ attributes = py::dict(*attributes);
+ } else {
+ attributes = py::dict();
+ }
+ if (attributes->contains("result_segment_sizes") ||
+ attributes->contains("operand_segment_sizes")) {
+ throw py::value_error("Manually setting a 'result_segment_sizes' or "
+ "'operand_segment_sizes' attribute is unsupported. "
+ "Use Operation.create for such low-level access.");
+ }
+
+ // Add result_segment_sizes attribute.
+ if (!resultSegmentLengths.empty()) {
+ int64_t size = resultSegmentLengths.size();
+ MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
+ mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
+ resultSegmentLengths.size(), resultSegmentLengths.data());
+ (*attributes)["result_segment_sizes"] =
+ PyAttribute(context, segmentLengthAttr);
+ }
+
+ // Add operand_segment_sizes attribute.
+ if (!operandSegmentLengths.empty()) {
+ int64_t size = operandSegmentLengths.size();
+ MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
+ mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
+ operandSegmentLengths.size(), operandSegmentLengths.data());
+ (*attributes)["operand_segment_sizes"] =
+ PyAttribute(context, segmentLengthAttr);
+ }
+ }
+
+ // Delegate to create.
+ return PyOperation::create(std::move(name), /*operands=*/std::move(operands),
+ /*results=*/std::move(resultTypes),
+ /*attributes=*/std::move(attributes),
+ /*successors=*/std::move(successors),
+ /*regions=*/*regions, location, maybeIp);
+}
+
PyOpView::PyOpView(py::object operationObject)
// Casting through the PyOperationBase base-class and then back to the
// Operation lets us accept any PyOperationBase subclass.
@@ -3397,17 +3665,29 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"Context that owns the Operation")
.def_property_readonly("opview", &PyOperation::createOpView);
- py::class_<PyOpView, PyOperationBase>(m, "OpView")
- .def(py::init<py::object>())
- .def_property_readonly("operation", &PyOpView::getOperationObject)
- .def_property_readonly(
- "context",
- [](PyOpView &self) {
- return self.getOperation().getContext().getObject();
- },
- "Context that owns the Operation")
- .def("__str__",
- [](PyOpView &self) { return py::str(self.getOperationObject()); });
+ auto opViewClass =
+ py::class_<PyOpView, PyOperationBase>(m, "OpView")
+ .def(py::init<py::object>())
+ .def_property_readonly("operation", &PyOpView::getOperationObject)
+ .def_property_readonly(
+ "context",
+ [](PyOpView &self) {
+ return self.getOperation().getContext().getObject();
+ },
+ "Context that owns the Operation")
+ .def("__str__", [](PyOpView &self) {
+ return py::str(self.getOperationObject());
+ });
+ opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
+ opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
+ opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
+ opViewClass.attr("_ods_build_default") = classmethod(
+ &PyOpView::odsBuildDefault, py::arg("cls"),
+ py::arg("operands") = py::none(), py::arg("results") = py::none(),
+ py::arg("attributes") = py::none(), py::arg("successors") = py::none(),
+ py::arg("regions") = py::none(), py::arg("loc") = py::none(),
+ py::arg("ip") = py::none(),
+ "Builds a specific, generated OpView based on class level attributes.");
//----------------------------------------------------------------------------
// Mapping of PyRegion.
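The region-count validation at the top of `PyOpView::odsBuildDefault` can be read as the following Python sketch. It assumes the `_ODS_REGIONS` tuple layout `(min_region_count, has_no_variadic_regions)` documented in this commit and reproduces its error messages; `resolve_region_count` is a hypothetical name used only for illustration.

```python
from typing import Optional, Tuple


def resolve_region_count(op_name: str, region_spec: Tuple[int, bool],
                         regions: Optional[int] = None) -> int:
    """Hypothetical helper: returns the region count to build with.

    region_spec is (min_region_count, has_no_variadic_regions).
    """
    min_count, has_no_variadic_regions = region_spec
    if regions is None:
        regions = min_count  # Default: build exactly the minimum.
    if regions < min_count:
        raise ValueError(
            f'Operation "{op_name}" requires a minimum of {min_count} '
            f"regions but was built with regions={regions}")
    if has_no_variadic_regions and regions > min_count:
        raise ValueError(
            f'Operation "{op_name}" requires a maximum of {min_count} '
            f"regions but was built with regions={regions}")
    return regions
```

The `OpView` default of `(0, True)` therefore means "exactly zero regions", while a spec such as `(2, False)` accepts any count of two or more.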
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index e789f536a829..443cdd691862 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -497,6 +497,14 @@ class PyOpView : public PyOperationBase {
pybind11::object getOperationObject() { return operationObject; }
+ static pybind11::object
+ odsBuildDefault(pybind11::object cls, pybind11::list operandList,
+ pybind11::list resultTypeList,
+ llvm::Optional<pybind11::dict> attributes,
+ llvm::Optional<std::vector<PyBlock *>> successors,
+ llvm::Optional<int> regions, DefaultingPyLocation location,
+ pybind11::object maybeIp);
+
private:
PyOperation &operation; // For efficient, cast-free access from C++
pybind11::object operationObject; // Holds the reference.
diff --git a/mlir/test/Bindings/Python/ods_helpers.py b/mlir/test/Bindings/Python/ods_helpers.py
new file mode 100644
index 000000000000..1db1112c4087
--- /dev/null
+++ b/mlir/test/Bindings/Python/ods_helpers.py
@@ -0,0 +1,210 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+
+def add_dummy_value():
+ return Operation.create(
+ "custom.value",
+ results=[IntegerType.get_signless(32)]).result
+
+
+def testOdsBuildDefaultImplicitRegions():
+
+ class TestFixedRegionsOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_REGIONS = (2, True)
+
+ class TestVariadicRegionsOp(OpView):
+ OPERATION_NAME = "custom.test_any_regions_op"
+ _ODS_REGIONS = (2, False)
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint.at_block_terminator(m.body):
+ op = TestFixedRegionsOp._ods_build_default(operands=[], results=[])
+ # CHECK: NUM_REGIONS: 2
+ print(f"NUM_REGIONS: {len(op.regions)}")
+ # Including a regions= that matches should be fine.
+ op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=2)
+ print(f"NUM_REGIONS: {len(op.regions)}")
+ # Reject greater than.
+ try:
+ op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=3)
+ except ValueError as e:
+ # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
+ print(f"ERROR:{e}")
+ # Reject less than.
+ try:
+ op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=1)
+ except ValueError as e:
+ # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
+ print(f"ERROR:{e}")
+
+ # If no regions specified for a variadic region op, build the minimum.
+ op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[])
+ # CHECK: DEFAULT_NUM_REGIONS: 2
+ print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
+ # Should also accept an explicit regions= that matches the minimum.
+ op = TestVariadicRegionsOp._ods_build_default(
+ operands=[], results=[], regions=2)
+ # CHECK: EQ_NUM_REGIONS: 2
+ print(f"EQ_NUM_REGIONS: {len(op.regions)}")
+ # And accept greater than the minimum, since this op's regions are
+ # variadic.
+ op = TestVariadicRegionsOp._ods_build_default(
+ operands=[], results=[], regions=3)
+ # CHECK: GT_NUM_REGIONS: 3
+ print(f"GT_NUM_REGIONS: {len(op.regions)}")
+ # Should reject less than minimum.
+ try:
+ op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[], regions=1)
+ except ValueError as e:
+ # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
+ print(f"ERROR:{e}")
+
+
+
+run(testOdsBuildDefaultImplicitRegions)
+
+
+def testOdsBuildDefaultNonVariadic():
+
+ class TestOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint.at_block_terminator(m.body):
+ v0 = add_dummy_value()
+ v1 = add_dummy_value()
+ t0 = IntegerType.get_signless(8)
+ t1 = IntegerType.get_signless(16)
+ op = TestOp._ods_build_default(operands=[v0, v1], results=[t0, t1])
+ # CHECK: %[[V0:.+]] = "custom.value"
+ # CHECK: %[[V1:.+]] = "custom.value"
+ # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
+ # CHECK-NOT: operand_segment_sizes
+ # CHECK-NOT: result_segment_sizes
+ # CHECK-SAME: : (i32, i32) -> (i8, i16)
+ print(m)
+
+run(testOdsBuildDefaultNonVariadic)
+
+
+def testOdsBuildDefaultSizedVariadic():
+
+ class TestOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_OPERAND_SEGMENTS = [1, -1, 0]
+ _ODS_RESULT_SEGMENTS = [-1, 0, 1]
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint.at_block_terminator(m.body):
+ v0 = add_dummy_value()
+ v1 = add_dummy_value()
+ v2 = add_dummy_value()
+ v3 = add_dummy_value()
+ t0 = IntegerType.get_signless(8)
+ t1 = IntegerType.get_signless(16)
+ t2 = IntegerType.get_signless(32)
+ t3 = IntegerType.get_signless(64)
+ # CHECK: %[[V0:.+]] = "custom.value"
+ # CHECK: %[[V1:.+]] = "custom.value"
+ # CHECK: %[[V2:.+]] = "custom.value"
+ # CHECK: %[[V3:.+]] = "custom.value"
+ # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
+ # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi64>
+ # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi64>
+ # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
+ op = TestOp._ods_build_default(
+ operands=[v0, [v1, v2], v3],
+ results=[[t0, t1], t2, t3])
+
+ # Now test with optional omitted.
+ # CHECK: "custom.test_op"(%[[V0]])
+ # CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]>
+ # CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]>
+ # CHECK-SAME: (i32) -> i64
+ op = TestOp._ods_build_default(
+ operands=[v0, None, None],
+ results=[None, None, t3])
+ print(m)
+
+ # And verify that errors are raised for None in a required operand.
+ try:
+ op = TestOp._ods_build_default(
+ operands=[None, None, None],
+ results=[None, None, t3])
+ except ValueError as e:
+ # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
+ print(f"OPERAND_CAST_ERROR:{e}")
+
+ # And verify that errors are raised for None in a required result.
+ try:
+ op = TestOp._ods_build_default(
+ operands=[v0, None, None],
+ results=[None, None, None])
+ except ValueError as e:
+ # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
+ print(f"RESULT_CAST_ERROR:{e}")
+
+ # Variadic lists with None elements should reject.
+ try:
+ op = TestOp._ods_build_default(
+ operands=[v0, [None], None],
+ results=[None, None, t3])
+ except ValueError as e:
+ # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
+ print(f"OPERAND_LIST_CAST_ERROR:{e}")
+ try:
+ op = TestOp._ods_build_default(
+ operands=[v0, None, None],
+ results=[[None], None, t3])
+ except ValueError as e:
+ # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
+ print(f"RESULT_LIST_CAST_ERROR:{e}")
+
+run(testOdsBuildDefaultSizedVariadic)
+
+
+def testOdsBuildDefaultCastError():
+
+ class TestOp(OpView):
+ OPERATION_NAME = "custom.test_op"
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint.at_block_terminator(m.body):
+ v0 = add_dummy_value()
+ v1 = add_dummy_value()
+ t0 = IntegerType.get_signless(8)
+ t1 = IntegerType.get_signless(16)
+ try:
+ op = TestOp._ods_build_default(
+ operands=[None, v1],
+ results=[t0, t1])
+ except ValueError as e:
+ # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
+ print(f"ERROR: {e}")
+ try:
+ op = TestOp._ods_build_default(
+ operands=[v0, v1],
+ results=[t0, None])
+ except ValueError as e:
+ # CHECK: Result 1 of operation "custom.test_op" must be a Type
+ print(f"ERROR: {e}")
+
+run(testOdsBuildDefaultCastError)
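The `operand_segment_sizes` and `result_segment_sizes` attributes checked by the test above come from the final merge step in `odsBuildDefault`. Here is a pure-Python sketch of that step, assuming the same rules as the C++ code: the caller's dict is copied rather than mutated, manually supplied segment-size attributes are rejected, and attributes pass through untouched when there are no segments. `merge_segment_sizes` is a hypothetical helper, and plain lists stand in for the `DenseElementsAttr` values of type `vector<Nxi64>` that the real code builds.

```python
from typing import Dict, List, Optional


def merge_segment_sizes(attributes: Optional[Dict[str, object]],
                        operand_lengths: List[int],
                        result_lengths: List[int]) -> Optional[Dict[str, object]]:
    """Hypothetical helper mirroring the attribute-merge step."""
    if not operand_lengths and not result_lengths:
        return attributes  # Nothing variadic: pass attributes through as-is.
    merged = dict(attributes or {})  # Copy so the caller's dict is untouched.
    if "result_segment_sizes" in merged or "operand_segment_sizes" in merged:
        raise ValueError(
            "Manually setting a 'result_segment_sizes' or "
            "'operand_segment_sizes' attribute is unsupported. "
            "Use Operation.create for such low-level access.")
    if result_lengths:
        merged["result_segment_sizes"] = list(result_lengths)
    if operand_lengths:
        merged["operand_segment_sizes"] = list(operand_lengths)
    return merged
```

Copying first matters because a generated `__init__` may pass in a dict it built itself, and reusing a caller-owned dict would otherwise leak the segment attributes back out.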
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 722cf9fb7e40..235cb4a1fa59 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -17,23 +17,18 @@ class TestOp<string mnemonic, list<OpTrait> traits = []> :
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
+// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,]
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
[AttrSizedOperandSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: operand_segment_sizes_ods = _ods_array.array('L')
- // CHECK: operands += [*variadic1]
- // CHECK: operand_segment_sizes_ods.append(len(variadic1))
+ // CHECK: operands.append(variadic1)
// CHECK: operands.append(non_variadic)
- // CHECK: operand_segment_sizes_ods.append(1)
// CHECK: if variadic2 is not None: operands.append(variadic2)
- // CHECK: operand_segment_sizes_ods.append(0 if variadic2 is None else 1)
- // CHECK: attributes["operand_segment_sizes"] = _ods_ir.DenseElementsAttr.get(operand_segment_sizes_ods,
- // CHECK: context=_ods_get_default_loc_context(loc))
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.attr_sized_operands", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -63,23 +58,18 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
+// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,]
def AttrSizedResultsOp : TestOp<"attr_sized_results",
[AttrSizedResultSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: result_segment_sizes_ods = _ods_array.array('L')
// CHECK: if variadic1 is not None: results.append(variadic1)
- // CHECK: result_segment_sizes_ods.append(0 if variadic1 is None else 1)
// CHECK: results.append(non_variadic)
- // CHECK: result_segment_sizes_ods.append(1) # non_variadic
// CHECK: if variadic2 is not None: results.append(variadic2)
- // CHECK: result_segment_sizes_ods.append(0 if variadic2 is None else 1)
- // CHECK: attributes["result_segment_sizes"] = _ods_ir.DenseElementsAttr.get(result_segment_sizes_ods,
- // CHECK: context=_ods_get_default_loc_context(loc))
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -110,6 +100,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
+// CHECK-NOT: _ODS_OPERAND_SEGMENTS
+// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOp : TestOp<"attributed_op"> {
// CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None):
// CHECK: operands = []
@@ -120,8 +112,8 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = in_
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -148,6 +140,8 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
+// CHECK-NOT: _ODS_OPERAND_SEGMENTS
+// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None):
// CHECK: operands = []
@@ -158,8 +152,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -183,8 +177,8 @@ def EmptyOp : TestOp<"empty">;
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.empty", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -201,8 +195,8 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32)
// CHECK: operands.append(_gen_arg_2)
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.missing_names", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -223,15 +217,17 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
+// CHECK-NOT: _ODS_OPERAND_SEGMENTS
+// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(non_variadic)
- // CHECK: operands += [*variadic]
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.one_variadic_operand", attributes=attributes, operands=operands, results=results,
+ // CHECK: operands.extend(variadic)
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -248,15 +244,17 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
+// CHECK-NOT: _ODS_OPERAND_SEGMENTS
+// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
- // CHECK: results += [*variadic]
+ // CHECK: results.extend(variadic)
// CHECK: results.append(non_variadic)
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.one_variadic_result", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -280,8 +278,8 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(in_)
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.python_keyword", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@@ -348,8 +346,8 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: results.append(f64)
// CHECK: operands.append(i32)
// CHECK: operands.append(f32)
- // CHECK: super().__init__(_ods_ir.Operation.create(
- // CHECK: "test.simple", attributes=attributes, operands=operands, results=results,
+ // CHECK: super().__init__(self._ods_build_default(
+ // CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
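The updated CHECK lines show a new delegation pattern. The generated `__init__` no longer splices the operation name into `Operation.create`. Instead, it hands the collected lists to `self._ods_build_default`. The sketch below models that pattern in plain Python. `FakeOpView` is a hypothetical stub standing in for `_ods_ir.OpView`. The real `_ods_build_default` is implemented natively and returns an actual operation; this stub only returns a dict, reading `OPERATION_NAME` from the class, so the example runs without MLIR.

```python
# Hypothetical stand-in for _ods_ir.OpView; it only records what the
# generated builder passes to _ods_build_default.
class FakeOpView:
    OPERATION_NAME = None

    def __init__(self, operation):
        self.operation = operation

    @classmethod
    def _ods_build_default(cls, *, attributes, operands, results, loc, ip):
        # The real method creates an Operation; this stub returns a dict and
        # takes the name from the class rather than from a spliced string.
        return {
            "name": cls.OPERATION_NAME,
            "attributes": attributes,
            "operands": operands,
            "results": results,
            "loc": loc,
            "ip": ip,
        }


class OneVariadicOperandOp(FakeOpView):
    OPERATION_NAME = "test.one_variadic_operand"

    # Mirrors the __init__ checked above: without a sized-segment trait, the
    # variadic group is flattened into the operand list with extend().
    def __init__(self, non_variadic, variadic, loc=None, ip=None):
        operands = []
        results = []
        attributes = {}
        operands.append(non_variadic)
        operands.extend(variadic)
        super().__init__(self._ods_build_default(
            attributes=attributes, operands=operands, results=results,
            loc=loc, ip=ip))


op = OneVariadicOperandOp("a", ["b", "c"])
print(op.operation["name"], op.operation["operands"])
# test.one_variadic_operand ['a', 'b', 'c']
```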
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 16bf6d1dc03f..658ad75eea28 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -26,7 +26,6 @@ using namespace mlir::tblgen;
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
-import array as _ods_array
from . import _cext as _ods_cext
from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir
@@ -51,6 +50,25 @@ class {0}(_ods_ir.OpView):
OPERATION_NAME = "{1}"
)Py";
+/// Template for class level declarations of operand and result
+/// segment specs.
+/// {0} is either "OPERAND" or "RESULT"
+/// {1} is the segment spec
+/// Each segment spec is either None (default) or a list of integers where:
+/// 1 = single element (expects a non-sequence operand/result)
+/// 0 = optional element (operand/result may be None)
+/// -1 = operand/result is a sequence corresponding to a variadic group
+constexpr const char *opClassSizedSegmentsTemplate = R"Py(
+ _ODS_{0}_SEGMENTS = {1}
+)Py";
+
+/// Template for class level declarations of the _ODS_REGIONS spec:
+/// {0} is the minimum number of regions
+/// {1} is the Python bool literal for hasNoVariadicRegions
+constexpr const char *opClassRegionSpecTemplate = R"Py(
+ _ODS_REGIONS = ({0}, {1})
+)Py";
+
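The class-level segment specs replace the `_segment_sizes_ods` bookkeeping that generated builders used to emit. The native helper that consumes them is not part of this excerpt. The following is therefore only a hypothetical pure-Python model of the idea. It assumes the encoding 1 = single element, 0 = optional element, -1 = variadic sequence. `resolve_segments` is an illustrative name, not an MLIR API.

```python
def resolve_segments(spec, values):
    """Flatten grouped builder arguments and compute their segment sizes.

    Hypothetical model of how a spec such as _ODS_OPERAND_SEGMENTS could be
    applied; not the actual MLIR implementation.
    """
    if spec is None:
        # No sized-segment trait: the values are already a flat list.
        return list(values), None
    if len(spec) != len(values):
        raise ValueError(f"expected {len(spec)} segments, got {len(values)}")
    flat, sizes = [], []
    for kind, value in zip(spec, values):
        if kind == 1:
            # Single element: must be neither None nor a sequence.
            if value is None or isinstance(value, (list, tuple)):
                raise ValueError("expected a single element")
            flat.append(value)
            sizes.append(1)
        elif kind == 0:
            # Optional element: None contributes an empty segment.
            if value is None:
                sizes.append(0)
            else:
                flat.append(value)
                sizes.append(1)
        elif kind == -1:
            # Variadic group: the builder appended the whole list as one entry.
            group = list(value)
            flat.extend(group)
            sizes.append(len(group))
        else:
            raise ValueError(f"unknown segment kind {kind}")
    return flat, sizes


flat, sizes = resolve_segments([1, -1, 0], ["a", ["b", "c"], None])
print(flat, sizes)  # ['a', 'b', 'c'] [1, 2, 0]
```

In this model, the computed sizes play the role of the values that the removed templates wrote into the `operand_segment_sizes` / `result_segment_sizes` attribute.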
/// Template for single-element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
@@ -446,18 +464,17 @@ static void emitAttributeAccessors(const Operator &op,
}
/// Template for the default auto-generated builder.
-/// {0} is the operation name;
-/// {1} is a comma-separated list of builder arguments, including the trailing
+/// {0} is a comma-separated list of builder arguments, including the trailing
/// `loc` and `ip`;
-/// {2} is the code populating `operands`, `results` and `attributes` fields.
+/// {1} is the code populating `operands`, `results` and `attributes` fields.
constexpr const char *initTemplate = R"Py(
- def __init__(self, {1}):
+ def __init__(self, {0}):
operands = []
results = []
attributes = {{}
- {2}
- super().__init__(_ods_ir.Operation.create(
- "{0}", attributes=attributes, operands=operands, results=results,
+ {1}
+ super().__init__(self._ods_build_default(
+ attributes=attributes, operands=operands, results=results,
loc=loc, ip=ip))
)Py";
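To see what the new `initTemplate` expands to, the sketch below renders an equivalent template with Python's `str.format`. It makes two assumptions. First, the indentation (2 spaces for the method, 4 for its body) is chosen for illustration, because whitespace in this excerpt is collapsed. Second, the braces differ from the C++ source: `llvm::formatv` only needs `{{` to emit a literal `{`, whereas `str.format` needs both braces doubled. `render_init` is a hypothetical helper that imitates the joins in `emitDefaultOpBuilder`.

```python
# Python rendition of the new initTemplate; note there is no operation-name
# placeholder any more, since _ods_build_default supplies it.
INIT_TEMPLATE = '''
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}}
    {1}
    super().__init__(self._ods_build_default(
      attributes=attributes, operands=operands, results=results,
      loc=loc, ip=ip))
'''


def render_init(builder_args, builder_lines):
    # Join the arguments with ", " and the body lines with a newline plus the
    # body indentation, as emitDefaultOpBuilder does with llvm::join.
    return INIT_TEMPLATE.format(", ".join(builder_args),
                                "\n    ".join(builder_lines))


src = render_init(["in_", "loc=None", "ip=None"], ["operands.append(in_)"])
print(src)
```

Wrapping the rendered text in a class body yields valid Python, which is a cheap way to sanity-check template edits.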
@@ -472,37 +489,10 @@ constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
constexpr const char *optionalAppendTemplate =
"if {1} is not None: {0}s.append({1})";
-/// Template for appending a variadic element to the operand/result list.
-/// {0} is either 'operand' or 'result';
-/// {1} is the field name.
-constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";
-
-/// Template for setting up the segment sizes buffer.
-constexpr const char *segmentDeclarationTemplate =
- "{0}_segment_sizes_ods = _ods_array.array('L')";
-
-/// Template for attaching segment sizes to the attribute list.
-constexpr const char *segmentAttributeTemplate =
- R"Py(attributes["{0}_segment_sizes"] = _ods_ir.DenseElementsAttr.get({0}_segment_sizes_ods,
- context=_ods_get_default_loc_context(loc)))Py";
-
-/// Template for appending the unit size to the segment sizes.
+/// Template for appending a list of elements to the operand/result list.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
-constexpr const char *singleElementSegmentTemplate =
- "{0}_segment_sizes_ods.append(1) # {1}";
-
-/// Template for appending 0/1 for an optional element to the segment sizes.
-/// {0} is either 'operand' or 'result';
-/// {1} is the field name.
-constexpr const char *optionalSegmentTemplate =
- "{0}_segment_sizes_ods.append(0 if {1} is None else 1)";
-
-/// Template for appending the length of a variadic group to the segment sizes.
-/// {0} is either 'operand' or 'result';
-/// {1} is the field name.
-constexpr const char *variadicSegmentTemplate =
- "{0}_segment_sizes_ods.append(len({1}))";
+constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
@@ -584,11 +574,7 @@ static void populateBuilderLines(
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
- // The segment sizes buffer only has to be populated if there attr-sized
- // segments trait is present.
- bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
- if (includeSegments)
- builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind));
+ bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
// For each element, find or generate a name.
for (int i = 0, e = getNumElements(op); i < e; ++i) {
@@ -596,28 +582,28 @@ static void populateBuilderLines(
std::string name = names[i];
// Choose the formatting string based on the element kind.
- llvm::StringRef formatString, segmentFormatString;
+ llvm::StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleElementAppendTemplate;
- segmentFormatString = singleElementSegmentTemplate;
} else if (element.isOptional()) {
formatString = optionalAppendTemplate;
- segmentFormatString = optionalSegmentTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
- formatString = variadicAppendTemplate;
- segmentFormatString = variadicSegmentTemplate;
+ // With sized segments, append the variadic list as one element so the
+ // segment spec can later split it; otherwise, flatten it into the
+ // operand/result list with multiElementAppendTemplate.
+ if (sizedSegments) {
+ // Append the list as is.
+ formatString = singleElementAppendTemplate;
+ } else {
+ // Append the list elements.
+ formatString = multiElementAppendTemplate;
+ }
}
// Add the lines.
builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
- if (includeSegments)
- builderLines.push_back(
- llvm::formatv(segmentFormatString.data(), kind, name));
}
-
- if (includeSegments)
- builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind));
}
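The template selection in `populateBuilderLines` can be mirrored in a few lines of Python. This is a simplified sketch. Element kinds are plain strings rather than `NamedTypeConstraint` objects, and `builder_lines` is a hypothetical name. With `sized_segments=False`, it reproduces the `operands.extend(variadic)` line checked for `OneVariadicOperandOp`.

```python
# Python equivalents of the three append templates used by the generator.
SINGLE = "{0}s.append({1})"
OPTIONAL = "if {1} is not None: {0}s.append({1})"
MULTI = "{0}s.extend({1})"


def builder_lines(kind, elements, sized_segments):
    """Pick an append template per element, as populateBuilderLines does.

    `elements` is a list of (name, category) pairs, where category is
    "single", "optional" or "variadic".
    """
    lines = []
    for name, category in elements:
        if category == "single":
            template = SINGLE
        elif category == "optional":
            template = OPTIONAL
        elif category == "variadic":
            # Keep the list intact when a segment spec will split it later;
            # otherwise flatten it into the operand/result list.
            template = SINGLE if sized_segments else MULTI
        else:
            raise ValueError(f"unknown category {category!r}")
        lines.append(template.format(kind, name))
    return lines


elems = [("non_variadic", "single"), ("variadic", "variadic")]
print(builder_lines("operand", elems, sized_segments=False))
# ['operands.append(non_variadic)', 'operands.extend(variadic)']
print(builder_lines("operand", elems, sized_segments=True))
# ['operands.append(non_variadic)', 'operands.append(variadic)']
```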
/// Emits a default builder constructing an operation from the list of its
@@ -645,8 +631,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
builderArgs.push_back("loc=None");
builderArgs.push_back("ip=None");
- os << llvm::formatv(initTemplate, op.getOperationName(),
- llvm::join(builderArgs, ", "),
+ os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
llvm::join(builderLines, "\n "));
}
@@ -659,12 +644,52 @@ static void constructAttributeMapping(const llvm::RecordKeeper &records,
}
}
+static void emitSegmentSpec(
+ const Operator &op, const char *kind,
+ llvm::function_ref<int(const Operator &)> getNumElements,
+ llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
+ getElement,
+ raw_ostream &os) {
+ std::string segmentSpec("[");
+ for (int i = 0, e = getNumElements(op); i < e; ++i) {
+ const NamedTypeConstraint &element = getElement(op, i);
+ if (element.isOptional()) {
+ segmentSpec.append("0,");
+ } else if (element.isVariadic()) {
+ segmentSpec.append("-1,");
+ } else {
+ segmentSpec.append("1,");
+ }
+ }
+ segmentSpec.append("]");
+
+ os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
+}
+
+static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
+ // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
+ // Note that the base OpView class defines this as (0, True).
+ unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
+ os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
+ op.hasNoVariadicRegions() ? "True" : "False");
+}
+
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op,
const AttributeClasses &attributeClasses,
raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
op.getOperationName());
+
+ // Sized segments.
+ if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
+ emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
+ }
+ if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
+ emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
+ }
+
+ emitRegionAttributes(op, os);
emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
emitAttributeAccessors(op, attributeClasses, os);
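The following Python mirror of `emitSegmentSpec` and `emitRegionAttributes` shows what the class-level attributes look like. It is a sketch, not the generator itself, and it differs in these ways:
- Kinds are strings.
- The emitted lines omit the class-body indentation so they can be executed directly.
- Optional elements are tested before variadic ones. An optional element is also variable-length, so testing `isVariableLength()` first would never produce 0.
- `hasNoVariadicRegions()` is assumed to mean "zero variadic regions".

```python
def segment_spec(kind, categories):
    """Render the class-level segment spec line, as emitSegmentSpec does."""
    spec = "["
    for category in categories:
        # Optional first: it is a special case of a variable-length element.
        if category == "optional":
            spec += "0,"
        elif category == "variadic":
            spec += "-1,"
        else:
            spec += "1,"
    spec += "]"
    return f"_ODS_{kind}_SEGMENTS = {spec}"


def region_spec(num_regions, num_variadic_regions):
    """Render _ODS_REGIONS = (min_region_count, has_no_variadic_regions)."""
    min_count = num_regions - num_variadic_regions
    no_variadic = "True" if num_variadic_regions == 0 else "False"
    return f"_ODS_REGIONS = ({min_count}, {no_variadic})"


line = segment_spec("OPERAND", ["single", "variadic", "optional"])
print(line)               # _ODS_OPERAND_SEGMENTS = [1,-1,0,]
print(region_spec(2, 0))  # _ODS_REGIONS = (2, True)

# Executing the emitted text confirms it is valid Python.
namespace = {}
exec(line, namespace)
print(namespace["_ODS_OPERAND_SEGMENTS"])  # [1, -1, 0]
```

The trailing comma left by the emitter is harmless, since `[1,-1,0,]` is a valid Python list literal.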