[Mlir-commits] [mlir] 8e6c55c - [mlir][python] Extend C/Python API to be usable for CFG construction.

Stella Laurenzo llvmlistbot at llvm.org
Mon Aug 30 08:28:41 PDT 2021


Author: Stella Laurenzo
Date: 2021-08-30T08:28:00-07:00
New Revision: 8e6c55c92c807ad77106298d9b7eaf453407d009

URL: https://github.com/llvm/llvm-project/commit/8e6c55c92c807ad77106298d9b7eaf453407d009
DIFF: https://github.com/llvm/llvm-project/commit/8e6c55c92c807ad77106298d9b7eaf453407d009.diff

LOG: [mlir][python] Extend C/Python API to be usable for CFG construction.

* It is pretty clear that no one has tried this yet since it was both incomplete and broken.
* Fixes a symbol-hiding issue (a local `mlirSuccessors` shadowing the outer one) that kept even the generic builder from constructing an operation with successors.
* Adds ODS support for successors.
* Adds C API functions `mlirBlockGetParentRegion` and `mlirRegionEqual`, with tests (plus a previously missing test for `mlirBlockGetParentOperation`).
* Adds Python property: `Block.region`.
* Adds Python methods: `Block.create_before` and `Block.create_after`.
* Adds Python property: `InsertionPoint.block`.
* Adds a new blocks.py test that verifies a plausible CFG construction case.

Differential Revision: https://reviews.llvm.org/D108898

Added: 
    mlir/test/python/ir/blocks.py

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6924fa88d3a91..ebc3ada600fde 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -447,6 +447,10 @@ MLIR_CAPI_EXPORTED void mlirRegionDestroy(MlirRegion region);
 /// Checks whether a region is null.
 static inline bool mlirRegionIsNull(MlirRegion region) { return !region.ptr; }
 
+/// Checks whether two region handles point to the same region. This does not
+/// perform deep comparison.
+MLIR_CAPI_EXPORTED bool mlirRegionEqual(MlirRegion region, MlirRegion other);
+
 /// Gets the first block in the region.
 MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
 
@@ -496,6 +500,9 @@ MLIR_CAPI_EXPORTED bool mlirBlockEqual(MlirBlock block, MlirBlock other);
 /// Returns the closest surrounding operation that contains this block.
 MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock);
 
+/// Returns the region that contains this block.
+MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block);
+
 /// Returns the block immediately following the given block in its parent
 /// region.
 MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block);

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d6305e7f49ec5..7add4eb7b0379 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -969,7 +969,6 @@ py::object PyOperation::create(
   }
   // Unpack/validate successors.
   if (successors) {
-    llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
     mlirSuccessors.reserve(successors->size());
     for (auto *successor : *successors) {
       // TODO: Verify successor originate from the same context.
@@ -2206,6 +2205,13 @@ void mlir::python::populateIRCore(py::module &m) {
             return self.getParentOperation()->createOpView();
           },
           "Returns the owning operation of this block.")
+      .def_property_readonly(
+          "region",
+          [](PyBlock &self) {
+            MlirRegion region = mlirBlockGetParentRegion(self.get());
+            return PyRegion(self.getParentOperation(), region);
+          },
+          "Returns the owning region of this block.")
       .def_property_readonly(
           "arguments",
           [](PyBlock &self) {
@@ -2218,6 +2224,40 @@ void mlir::python::populateIRCore(py::module &m) {
             return PyOperationList(self.getParentOperation(), self.get());
           },
           "Returns a forward-optimized sequence of operations.")
+      .def(
+          "create_before",
+          [](PyBlock &self, py::args pyArgTypes) {
+            self.checkValid();
+            llvm::SmallVector<MlirType, 4> argTypes;
+            argTypes.reserve(pyArgTypes.size());
+            for (auto &pyArg : pyArgTypes) {
+              argTypes.push_back(pyArg.cast<PyType &>());
+            }
+
+            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
+            MlirRegion region = mlirBlockGetParentRegion(self.get());
+            mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
+            return PyBlock(self.getParentOperation(), block);
+          },
+          "Creates and returns a new Block before this block "
+          "(with given argument types).")
+      .def(
+          "create_after",
+          [](PyBlock &self, py::args pyArgTypes) {
+            self.checkValid();
+            llvm::SmallVector<MlirType, 4> argTypes;
+            argTypes.reserve(pyArgTypes.size());
+            for (auto &pyArg : pyArgTypes) {
+              argTypes.push_back(pyArg.cast<PyType &>());
+            }
+
+            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
+            MlirRegion region = mlirBlockGetParentRegion(self.get());
+            mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
+            return PyBlock(self.getParentOperation(), block);
+          },
+          "Creates and returns a new Block after this block "
+          "(with given argument types).")
       .def(
           "__iter__",
           [](PyBlock &self) {
@@ -2270,7 +2310,10 @@ void mlir::python::populateIRCore(py::module &m) {
       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
                   py::arg("block"), "Inserts before the block terminator.")
       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
-           "Inserts an operation.");
+           "Inserts an operation.")
+      .def_property_readonly(
+          "block", [](PyInsertionPoint &self) { return self.getBlock(); },
+          "Returns the block that this InsertionPoint points to.");
 
   //----------------------------------------------------------------------------
   // Mapping of PyAttribute.

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 2721efde31f89..68037f0afe9c3 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -427,6 +427,10 @@ bool mlirOperationVerify(MlirOperation op) {
 
 MlirRegion mlirRegionCreate() { return wrap(new Region); }
 
+bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
+  return unwrap(region) == unwrap(other);
+}
+
 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
   Region *cppRegion = unwrap(region);
   if (cppRegion->empty())
@@ -492,6 +496,10 @@ MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
   return wrap(unwrap(block)->getParentOp());
 }
 
+MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
+  return wrap(unwrap(block)->getParent());
+}
+
 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
   return wrap(unwrap(block)->getNextNode());
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 296b68b0eb464..1acd4b22bbf48 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -323,13 +323,20 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
 
   // Verify that parent operation and block report correctly.
+  // CHECK: Parent operation eq: 1
   fprintf(stderr, "Parent operation eq: %d\n",
           mlirOperationEqual(mlirOperationGetParentOperation(operation),
                              parentOperation));
+  // CHECK: Block eq: 1
   fprintf(stderr, "Block eq: %d\n",
           mlirBlockEqual(mlirOperationGetBlock(operation), block));
-  // CHECK: Parent operation eq: 1
-  // CHECK: Block eq: 1
+  // CHECK: Block parent operation eq: 1
+  fprintf(
+      stderr, "Block parent operation eq: %d\n",
+      mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation));
+  // CHECK: Block parent region eq: 1
+  fprintf(stderr, "Block parent region eq: %d\n",
+          mlirRegionEqual(mlirBlockGetParentRegion(block), region));
 
   // In the module we created, the first operation of the first function is
   // an "memref.dim", which has an attribute and a single result that we can
@@ -441,7 +448,8 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
       operation, mlirStringRefCreateFromCString("elts"),
       mlirDenseElementsAttrInt32Get(
           mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
-                                  mlirAttributeGetNull()), 4, eltsData));
+                                  mlirAttributeGetNull()),
+          4, eltsData));
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
   mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
   mlirOpPrintingFlagsPrintGenericOpForm(flags);
@@ -909,25 +917,25 @@ int printBuiltinAttributes(MlirContext ctx) {
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
       2, ints8);
   MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
+                              encoding),
       2, uints32);
   MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
       2, ints32);
   MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeUnsignedGet(ctx, 64), encoding),
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
+                              encoding),
       2, uints64);
   MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
       2, ints64);
   MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
-      2, floats);
+      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2,
+      floats);
   MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
-      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
-      2, doubles);
+      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
+      doubles);
 
   if (!mlirAttributeIsADenseElements(boolElements) ||
       !mlirAttributeIsADenseElements(uint8Elements) ||
@@ -1084,8 +1092,8 @@ int printBuiltinAttributes(MlirContext ctx) {
       mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
       2, indices);
   MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
-      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding),
-      2, floats);
+      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding), 2,
+      floats);
   MlirAttribute sparseAttr = mlirSparseElementsAttribute(
       mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
       indicesAttr, valuesAttr);
@@ -1635,11 +1643,12 @@ int testClone() {
   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std"));
   MlirLocation loc = mlirLocationUnknownGet(ctx);
   MlirType indexType = mlirIndexTypeGet(ctx);
-  MlirStringRef valueStringRef =  mlirStringRefCreateFromCString("value");
+  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
 
   MlirAttribute indexZeroLiteral =
       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
-  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
   MlirOperationState constZeroState = mlirOperationStateGet(
       mlirStringRefCreateFromCString("std.constant"), loc);
   mlirOperationStateAddResults(&constZeroState, 1, &indexType);

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 95c58d5c3c25b..572c657336686 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -27,9 +27,10 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:   operands.append(variadic1)
   // CHECK:   operands.append(non_variadic)
   // CHECK:   if variadic2 is not None: operands.append(variadic2)
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
@@ -68,9 +69,10 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
   // CHECK:   if variadic1 is not None: results.append(variadic1)
   // CHECK:   results.append(non_variadic)
   // CHECK:   if variadic2 is not None: results.append(variadic2)
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
@@ -112,9 +114,10 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
   // CHECK:   attributes["in"] = in_
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def i32attr(self):
@@ -152,9 +155,10 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
   // CHECK:   if is_ is not None: attributes["is"] = is_
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def in_(self):
@@ -177,9 +181,10 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -195,9 +200,10 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:   operands.append(_gen_arg_0)
   // CHECK:   operands.append(f32)
   // CHECK:   operands.append(_gen_arg_2)
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def f32(self):
@@ -226,9 +232,10 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:   attributes = {}
   // CHECK:   operands.append(non_variadic)
   // CHECK:   operands.extend(variadic)
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
@@ -253,9 +260,10 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   // CHECK:   attributes = {}
   // CHECK:   results.extend(variadic)
   // CHECK:   results.append(non_variadic)
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def variadic(self):
@@ -278,9 +286,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   operands.append(in_)
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def in_(self):
@@ -346,9 +355,10 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:   results.append(f64)
   // CHECK:   operands.append(i32)
   // CHECK:   operands.append(f32)
+  // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def i32(self):
@@ -368,3 +378,15 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:   return self.operation.results[1]
   let results = (outs I64:$i64, F64:$f64);
 }
+
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
+def WithSuccessorsOp : TestOp<"with_successors"> {
+  // CHECK-NOT:  _ods_successors = None
+  // CHECK:      _ods_successors = []
+  // CHECK-NEXT: _ods_successors.append(successor)
+  // CHECK-NEXT: _ods_successors.extend(successors)
+  let successors = (successor AnySuccessor:$successor,
+                              VariadicSuccessor<AnySuccessor>:$successors);
+}

diff  --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
new file mode 100644
index 0000000000000..81dccdd1f52d5
--- /dev/null
+++ b/mlir/test/python/ir/blocks.py
@@ -0,0 +1,53 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import io
+import itertools
+from mlir.ir import *
+from mlir.dialects import builtin
+# Note: std dialect needed for terminators.
+from mlir.dialects import std
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+  return f
+
+
+# CHECK-LABEL: TEST: testBlockCreation
+# CHECK: func @test(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16)
+# CHECK:   br ^bb1(%[[ARG1]] : i16)
+# CHECK: ^bb1(%[[PHI0:.*]]: i16):
+# CHECK:   br ^bb2(%[[ARG0]] : i32)
+# CHECK: ^bb2(%[[PHI1:.*]]: i32):
+# CHECK:   return
+@run
+def testBlockCreation():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      f_type = FunctionType.get(
+          [IntegerType.get_signless(32),
+           IntegerType.get_signless(16)], [])
+      f_op = builtin.FuncOp("test", f_type)
+      entry_block = f_op.add_entry_block()
+      i32_arg, i16_arg = entry_block.arguments
+      successor_block = entry_block.create_after(i32_arg.type)
+      with InsertionPoint(successor_block) as successor_ip:
+        assert successor_ip.block == successor_block
+        std.ReturnOp([])
+      middle_block = successor_block.create_before(i16_arg.type)
+
+      with InsertionPoint(entry_block) as entry_ip:
+        assert entry_ip.block == entry_block
+        std.BranchOp([i16_arg], dest=middle_block)
+
+      with InsertionPoint(middle_block) as middle_ip:
+        assert middle_ip.block == middle_block
+        std.BranchOp([i32_arg], dest=successor_block)
+    print(module.operation)
+    # Ensure region back references are coherent.
+    assert entry_block.region == middle_block.region == successor_block.region

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index de84b8a24e289..742cad748ca46 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -475,7 +475,8 @@ static void emitAttributeAccessors(const Operator &op,
 /// Template for the default auto-generated builder.
 ///   {0} is a comma-separated list of builder arguments, including the trailing
 ///       `loc` and `ip`;
-///   {1} is the code populating `operands`, `results` and `attributes` fields.
+///   {1} is the code populating `operands`, `results`, `attributes` and
+///       `_ods_successors`.
 constexpr const char *initTemplate = R"Py(
   def __init__(self, {0}):
     operands = []
@@ -484,7 +485,7 @@ constexpr const char *initTemplate = R"Py(
     {1}
     super().__init__(self.build_generic(
       attributes=attributes, results=results, operands=operands,
-      loc=loc, ip=ip))
+      successors=_ods_successors, loc=loc, ip=ip))
 )Py";
 
 /// Template for appending a single element to the operand/result list.
@@ -518,6 +519,16 @@ constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
       _ods_get_default_loc_context(loc)))Py";
 
+/// Template to initialize `_ods_successors` in the builder. It is always
+/// emitted: `None` for ops without successors, `[]` otherwise.
+///   {0} is the value to initialize the successors list to.
+constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
+
+/// Template to append or extend the list of successors in the builder.
+///   {0} is the list method ('append' or 'extend');
+///   {1} is the value to add.
+constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
+
 /// Populates `builderArgs` with the Python-compatible names of builder function
 /// arguments, first the results, then the intermixed attributes and operands in
 /// the same order as they appear in the `arguments` field of the op definition.
@@ -526,7 +537,8 @@ constexpr const char *initUnitAttributeTemplate =
 static void
 populateBuilderArgs(const Operator &op,
                     llvm::SmallVectorImpl<std::string> &builderArgs,
-                    llvm::SmallVectorImpl<std::string> &operandNames) {
+                    llvm::SmallVectorImpl<std::string> &operandNames,
+                    llvm::SmallVectorImpl<std::string> &successorArgNames) {
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     std::string name = op.getResultName(i).str();
     if (name.empty()) {
@@ -550,6 +562,16 @@ populateBuilderArgs(const Operator &op,
     if (!op.getArg(i).is<NamedAttribute *>())
       operandNames.push_back(name);
   }
+
+  for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
+    NamedSuccessor successor = op.getSuccessor(i);
+    std::string name = std::string(successor.name);
+    if (name.empty())
+      name = llvm::formatv("_gen_successor_{0}", i);
+    name = sanitizeName(name);
+    builderArgs.push_back(name);
+    successorArgNames.push_back(name);
+  }
 }
 
 /// Populates `builderLines` with additional lines that are required in the
@@ -581,6 +603,27 @@ populateBuilderLinesAttr(const Operator &op,
   }
 }
 
+/// Populates `builderLines` with additional lines that are required in the
+/// builder to set up successors. `successorArgNames[i]` must be the Python
+/// argument name of the op's i-th successor.
+static void populateBuilderLinesSuccessors(
+    const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
+    llvm::SmallVectorImpl<std::string> &builderLines) {
+  if (successorArgNames.empty()) {
+    builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
+    return;
+  }
+
+  builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
+  for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
+    auto &argName = successorArgNames[i];
+    const NamedSuccessor &successor = op.getSuccessor(i);
+    builderLines.push_back(
+        llvm::formatv(addSuccessorTemplate,
+                      successor.isVariadic() ? "extend" : "append", argName));
+  }
+}
+
 /// Populates `builderLines` with additional lines that are required in the
 /// builder. `kind` must be either "operand" or "result". `names` contains the
 /// names of init arguments that correspond to the elements.
@@ -629,12 +672,14 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
   if (op.skipDefaultBuilders())
     return;
 
-  llvm::SmallVector<std::string, 8> builderArgs;
-  llvm::SmallVector<std::string, 8> builderLines;
-  llvm::SmallVector<std::string, 4> operandArgNames;
+  llvm::SmallVector<std::string> builderArgs;
+  llvm::SmallVector<std::string> builderLines;
+  llvm::SmallVector<std::string> operandArgNames;
+  llvm::SmallVector<std::string> successorArgNames;
   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
-                      op.getNumNativeAttributes());
-  populateBuilderArgs(op, builderArgs, operandArgNames);
+                      op.getNumNativeAttributes() + op.getNumSuccessors());
+  populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
+
   populateBuilderLines(
       op, "result",
       llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
@@ -644,6 +689,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
   populateBuilderLinesAttr(
       op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
       builderLines);
+  populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
 
   builderArgs.push_back("*");
   builderArgs.push_back("loc=None");


        


More information about the Mlir-commits mailing list