[Mlir-commits] [mlir] 8e6c55c - [mlir][python] Extend C/Python API to be usable for CFG construction.
Stella Laurenzo
llvmlistbot at llvm.org
Mon Aug 30 08:28:41 PDT 2021
Author: Stella Laurenzo
Date: 2021-08-30T08:28:00-07:00
New Revision: 8e6c55c92c807ad77106298d9b7eaf453407d009
URL: https://github.com/llvm/llvm-project/commit/8e6c55c92c807ad77106298d9b7eaf453407d009
DIFF: https://github.com/llvm/llvm-project/commit/8e6c55c92c807ad77106298d9b7eaf453407d009.diff
LOG: [mlir][python] Extend C/Python API to be usable for CFG construction.
* Evidently no one had tried this yet: the API was both incomplete and broken.
* Fixes a symbol-shadowing issue that kept even the generic builder from constructing an operation with successors.
* Adds ODS support for successors.
* Adds C API functions `mlirBlockGetParentRegion` and `mlirRegionEqual` with tests (plus the previously missing test for `mlirBlockGetParentOperation`).
* Adds Python property: `Block.region`.
* Adds Python methods: `Block.create_before` and `Block.create_after`.
* Adds Python property: `InsertionPoint.block`.
* Adds new blocks.py test to verify a plausible CFG construction case.
Differential Revision: https://reviews.llvm.org/D108898
Added:
mlir/test/python/ir/blocks.py
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6924fa88d3a91..ebc3ada600fde 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -447,6 +447,10 @@ MLIR_CAPI_EXPORTED void mlirRegionDestroy(MlirRegion region);
/// Checks whether a region is null.
static inline bool mlirRegionIsNull(MlirRegion region) { return !region.ptr; }
+/// Checks whether two region handles point to the same region. This does not
+/// perform deep comparison.
+MLIR_CAPI_EXPORTED bool mlirRegionEqual(MlirRegion region, MlirRegion other);
+
/// Gets the first block in the region.
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
@@ -496,6 +500,9 @@ MLIR_CAPI_EXPORTED bool mlirBlockEqual(MlirBlock block, MlirBlock other);
/// Returns the closest surrounding operation that contains this block.
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock);
+/// Returns the region that contains this block.
+MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block);
+
/// Returns the block immediately following the given block in its parent
/// region.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d6305e7f49ec5..7add4eb7b0379 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -969,7 +969,6 @@ py::object PyOperation::create(
}
// Unpack/validate successors.
if (successors) {
- llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
mlirSuccessors.reserve(successors->size());
for (auto *successor : *successors) {
// TODO: Verify successor originate from the same context.
@@ -2206,6 +2205,13 @@ void mlir::python::populateIRCore(py::module &m) {
return self.getParentOperation()->createOpView();
},
"Returns the owning operation of this block.")
+ .def_property_readonly(
+ "region",
+ [](PyBlock &self) {
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ return PyRegion(self.getParentOperation(), region);
+ },
+ "Returns the owning region of this block.")
.def_property_readonly(
"arguments",
[](PyBlock &self) {
@@ -2218,6 +2224,40 @@ void mlir::python::populateIRCore(py::module &m) {
return PyOperationList(self.getParentOperation(), self.get());
},
"Returns a forward-optimized sequence of operations.")
+ .def(
+ "create_before",
+ [](PyBlock &self, py::args pyArgTypes) {
+ self.checkValid();
+ llvm::SmallVector<MlirType, 4> argTypes;
+ argTypes.reserve(pyArgTypes.size());
+ for (auto &pyArg : pyArgTypes) {
+ argTypes.push_back(pyArg.cast<PyType &>());
+ }
+
+ MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ "Creates and returns a new Block before this block "
+ "(with given argument types).")
+ .def(
+ "create_after",
+ [](PyBlock &self, py::args pyArgTypes) {
+ self.checkValid();
+ llvm::SmallVector<MlirType, 4> argTypes;
+ argTypes.reserve(pyArgTypes.size());
+ for (auto &pyArg : pyArgTypes) {
+ argTypes.push_back(pyArg.cast<PyType &>());
+ }
+
+ MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ "Creates and returns a new Block after this block "
+ "(with given argument types).")
.def(
"__iter__",
[](PyBlock &self) {
@@ -2270,7 +2310,10 @@ void mlir::python::populateIRCore(py::module &m) {
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
py::arg("block"), "Inserts before the block terminator.")
.def("insert", &PyInsertionPoint::insert, py::arg("operation"),
- "Inserts an operation.");
+ "Inserts an operation.")
+ .def_property_readonly(
+ "block", [](PyInsertionPoint &self) { return self.getBlock(); },
+ "Returns the block that this InsertionPoint points to.");
//----------------------------------------------------------------------------
// Mapping of PyAttribute.
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 2721efde31f89..68037f0afe9c3 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -427,6 +427,10 @@ bool mlirOperationVerify(MlirOperation op) {
MlirRegion mlirRegionCreate() { return wrap(new Region); }
+bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
+ return unwrap(region) == unwrap(other); // Region pointer identity only; no deep comparison.
+}
+
MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
Region *cppRegion = unwrap(region);
if (cppRegion->empty())
@@ -492,6 +496,10 @@ MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
return wrap(unwrap(block)->getParentOp());
}
+MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
+ return wrap(unwrap(block)->getParent()); // Owning Region; presumably null for a detached block -- confirm.
+}
+
MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
return wrap(unwrap(block)->getNextNode());
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 296b68b0eb464..1acd4b22bbf48 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -323,13 +323,20 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
// Verify that parent operation and block report correctly.
+ // CHECK: Parent operation eq: 1
fprintf(stderr, "Parent operation eq: %d\n",
mlirOperationEqual(mlirOperationGetParentOperation(operation),
parentOperation));
+ // CHECK: Block eq: 1
fprintf(stderr, "Block eq: %d\n",
mlirBlockEqual(mlirOperationGetBlock(operation), block));
- // CHECK: Parent operation eq: 1
- // CHECK: Block eq: 1
+ // CHECK: Block parent operation eq: 1
+ fprintf(
+ stderr, "Block parent operation eq: %d\n",
+ mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation));
+ // CHECK: Block parent region eq: 1
+ fprintf(stderr, "Block parent region eq: %d\n",
+ mlirRegionEqual(mlirBlockGetParentRegion(block), region));
// In the module we created, the first operation of the first function is
// an "memref.dim", which has an attribute and a single result that we can
@@ -441,7 +448,8 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
operation, mlirStringRefCreateFromCString("elts"),
mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
- mlirAttributeGetNull()), 4, eltsData));
+ mlirAttributeGetNull()),
+ 4, eltsData));
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
mlirOpPrintingFlagsPrintGenericOpForm(flags);
@@ -909,25 +917,25 @@ int printBuiltinAttributes(MlirContext ctx) {
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
2, ints8);
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
- mlirRankedTensorTypeGet(2, shape,
- mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
+ encoding),
2, uints32);
MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
2, ints32);
MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
- mlirRankedTensorTypeGet(2, shape,
- mlirIntegerTypeUnsignedGet(ctx, 64), encoding),
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
+ encoding),
2, uints64);
MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
2, ints64);
MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
- mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
- 2, floats);
+ mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2,
+ floats);
MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
- mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
- 2, doubles);
+ mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
+ doubles);
if (!mlirAttributeIsADenseElements(boolElements) ||
!mlirAttributeIsADenseElements(uint8Elements) ||
@@ -1084,8 +1092,8 @@ int printBuiltinAttributes(MlirContext ctx) {
mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
2, indices);
MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
- mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding),
- 2, floats);
+ mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding), 2,
+ floats);
MlirAttribute sparseAttr = mlirSparseElementsAttribute(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
indicesAttr, valuesAttr);
@@ -1635,11 +1643,12 @@ int testClone() {
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std"));
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirType indexType = mlirIndexTypeGet(ctx);
- MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
+ MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
MlirAttribute indexZeroLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
- MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
+ MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+ mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
mlirStringRefCreateFromCString("std.constant"), loc);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 95c58d5c3c25b..572c657336686 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -27,9 +27,10 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: operands.append(variadic1)
// CHECK: operands.append(non_variadic)
// CHECK: if variadic2 is not None: operands.append(variadic2)
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def variadic1(self):
@@ -68,9 +69,10 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: if variadic1 is not None: results.append(variadic1)
// CHECK: results.append(non_variadic)
// CHECK: if variadic2 is not None: results.append(variadic2)
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def variadic1(self):
@@ -112,9 +114,10 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = in_
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def i32attr(self):
@@ -152,9 +155,10 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def in_(self):
@@ -177,9 +181,10 @@ def EmptyOp : TestOp<"empty">;
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -195,9 +200,10 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32)
// CHECK: operands.append(_gen_arg_2)
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def f32(self):
@@ -226,9 +232,10 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: attributes = {}
// CHECK: operands.append(non_variadic)
// CHECK: operands.extend(variadic)
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
@@ -253,9 +260,10 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: attributes = {}
// CHECK: results.extend(variadic)
// CHECK: results.append(non_variadic)
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def variadic(self):
@@ -278,9 +286,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(in_)
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def in_(self):
@@ -346,9 +355,10 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: results.append(f64)
// CHECK: operands.append(i32)
// CHECK: operands.append(f32)
+ // CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
- // CHECK: loc=loc, ip=ip))
+ // CHECK: successors=_ods_successors, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def i32(self):
@@ -368,3 +378,15 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: return self.operation.results[1]
let results = (outs I64:$i64, F64:$f64);
}
+
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
+def WithSuccessorsOp : TestOp<"with_successors"> { // One single and one variadic successor: append, then extend.
+ // CHECK-NOT: _ods_successors = None
+ // CHECK: _ods_successors = []
+ // CHECK-NEXT: _ods_successors.append(successor)
+ // CHECK-NEXT: _ods_successors.extend(successors)
+ let successors = (successor AnySuccessor:$successor,
+ VariadicSuccessor<AnySuccessor>:$successors);
+}
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
new file mode 100644
index 0000000000000..81dccdd1f52d5
--- /dev/null
+++ b/mlir/test/python/ir/blocks.py
@@ -0,0 +1,53 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import io
+import itertools
+from mlir.ir import *
+from mlir.dialects import builtin
+# Note: std dialect needed for terminators.
+from mlir.dialects import std
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()  # Release unreachable Python wrappers before the live-Context check.
+ assert Context._get_live_count() == 0  # No Context may outlive the test.
+ return f
+
+
+# CHECK-LABEL: TEST: testBlockCreation
+# CHECK: func @test(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16)
+# CHECK: br ^bb1(%[[ARG1]] : i16)
+# CHECK: ^bb1(%[[PHI0:.*]]: i16):
+# CHECK: br ^bb2(%[[ARG0]] : i32)
+# CHECK: ^bb2(%[[PHI1:.*]]: i32):
+# CHECK: return
+@run
+def testBlockCreation():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f_type = FunctionType.get(
+ [IntegerType.get_signless(32),
+ IntegerType.get_signless(16)], [])
+ f_op = builtin.FuncOp("test", f_type)
+ entry_block = f_op.add_entry_block()
+ i32_arg, i16_arg = entry_block.arguments
+ successor_block = entry_block.create_after(i32_arg.type)  # Printed as ^bb2.
+ with InsertionPoint(successor_block) as successor_ip:
+ assert successor_ip.block == successor_block
+ std.ReturnOp([])
+ middle_block = successor_block.create_before(i16_arg.type)  # Lands between entry and successor (^bb1).
+
+ with InsertionPoint(entry_block) as entry_ip:
+ assert entry_ip.block == entry_block
+ std.BranchOp([i16_arg], dest=middle_block)
+
+ with InsertionPoint(middle_block) as middle_ip:
+ assert middle_ip.block == middle_block
+ std.BranchOp([i32_arg], dest=successor_block)
+ print(module.operation)
+ # Ensure region back references are coherent.
+ assert entry_block.region == middle_block.region == successor_block.region
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index de84b8a24e289..742cad748ca46 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -475,7 +475,8 @@ static void emitAttributeAccessors(const Operator &op,
/// Template for the default auto-generated builder.
/// {0} is a comma-separated list of builder arguments, including the trailing
/// `loc` and `ip`;
-/// {1} is the code populating `operands`, `results` and `attributes` fields.
+/// {1} is the code populating `operands`, `results`, `attributes` and
+/// `_ods_successors`.
constexpr const char *initTemplate = R"Py(
def __init__(self, {0}):
operands = []
@@ -484,7 +485,7 @@ constexpr const char *initTemplate = R"Py(
{1}
super().__init__(self.build_generic(
attributes=attributes, results=results, operands=operands,
- loc=loc, ip=ip))
+ successors=_ods_successors, loc=loc, ip=ip))
)Py";
/// Template for appending a single element to the operand/result list.
@@ -518,6 +519,16 @@ constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
_ods_get_default_loc_context(loc)))Py";
+/// Template to initialize the successors list in the builder. It is emitted
+/// for every op; the value is `None` when the op declares no successors.
+/// {0} is the value to initialize the successors list to.
+constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
+
+/// Template to append or extend the list of successors in the builder.
+/// {0} is the list method ('append' for a single successor, 'extend' for a
+/// variadic one); {1} is the Python argument name holding the value to add.
+constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
+
/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments, first the results, then the intermixed attributes and operands in
/// the same order as they appear in the `arguments` field of the op definition.
@@ -526,7 +537,8 @@ constexpr const char *initUnitAttributeTemplate =
static void
populateBuilderArgs(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
- llvm::SmallVectorImpl<std::string> &operandNames) {
+ llvm::SmallVectorImpl<std::string> &operandNames,
+ llvm::SmallVectorImpl<std::string> &successorArgNames) {
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
std::string name = op.getResultName(i).str();
if (name.empty()) {
@@ -550,6 +562,16 @@ populateBuilderArgs(const Operator &op,
if (!op.getArg(i).is<NamedAttribute *>())
operandNames.push_back(name);
}
+
+ for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
+ NamedSuccessor successor = op.getSuccessor(i);
+ std::string name = std::string(successor.name);
+ if (name.empty())
+ name = llvm::formatv("_gen_successor_{0}", i);
+ name = sanitizeName(name);
+ builderArgs.push_back(name);
+ successorArgNames.push_back(name);
+ }
}
/// Populates `builderLines` with additional lines that are required in the
@@ -581,6 +603,27 @@ populateBuilderLinesAttr(const Operator &op,
}
}
+/// Populates `builderLines` with the lines that set up `_ods_successors`:
+/// `None` if the op has no successors, else a list built from each name in
+/// `successorArgNames` (one Python argument name per op successor, in order).
+static void populateBuilderLinesSuccessors(
+ const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
+ llvm::SmallVectorImpl<std::string> &builderLines) {
+ if (successorArgNames.empty()) {
+ builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); // initTemplate always passes successors=_ods_successors.
+ return;
+ }
+
+ builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
+ for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
+ auto &argName = successorArgNames[i];
+ const NamedSuccessor &successor = op.getSuccessor(i); // Index i pairs the name with the op's i-th successor.
+ builderLines.push_back(
+ llvm::formatv(addSuccessorTemplate,
+ successor.isVariadic() ? "extend" : "append", argName)); // Variadic successors arrive as a sequence.
+ }
+}
+
/// Populates `builderLines` with additional lines that are required in the
/// builder. `kind` must be either "operand" or "result". `names` contains the
/// names of init arguments that correspond to the elements.
@@ -629,12 +672,14 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
if (op.skipDefaultBuilders())
return;
- llvm::SmallVector<std::string, 8> builderArgs;
- llvm::SmallVector<std::string, 8> builderLines;
- llvm::SmallVector<std::string, 4> operandArgNames;
+ llvm::SmallVector<std::string> builderArgs;
+ llvm::SmallVector<std::string> builderLines;
+ llvm::SmallVector<std::string> operandArgNames;
+ llvm::SmallVector<std::string> successorArgNames;
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
- op.getNumNativeAttributes());
- populateBuilderArgs(op, builderArgs, operandArgNames);
+ op.getNumNativeAttributes() + op.getNumSuccessors());
+ populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
+
populateBuilderLines(
op, "result",
llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
@@ -644,6 +689,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
populateBuilderLinesAttr(
op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
builderLines);
+ populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
builderArgs.push_back("*");
builderArgs.push_back("loc=None");
More information about the Mlir-commits
mailing list