[Mlir-commits] [mlir] 18fbd5f - [mlir][python] Better support for variadic regions in Python bindings

Alex Zinenko llvmlistbot at llvm.org
Thu Oct 14 04:15:18 PDT 2021


Author: Alex Zinenko
Date: 2021-10-14T13:15:13+02:00
New Revision: 18fbd5fe34f0f01a4f013a3864b6dc681ada58b1

URL: https://github.com/llvm/llvm-project/commit/18fbd5fe34f0f01a4f013a3864b6dc681ada58b1
DIFF: https://github.com/llvm/llvm-project/commit/18fbd5fe34f0f01a4f013a3864b6dc681ada58b1.diff

LOG: [mlir][python] Better support for variadic regions in Python bindings

Improve support for variadic regions in ODS-generated operation view classes.
In particular, make generated constructors take an extra argument that
specifies the number of variadic regions if the operation has them. Previously,
there was no mechanism to specify a non-zero number of variadic regions. Also
generate named accessors to regions.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D111783

Added: 
    

Modified: 
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index c3ee0c47aa05a..9355fc4aeb3ab 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -24,13 +24,14 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_results_or_values(variadic1))
   // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
   // CHECK:   if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
@@ -66,13 +67,14 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   if variadic1 is not None: results.append(variadic1)
   // CHECK:   results.append(non_variadic)
   // CHECK:   if variadic2 is not None: results.append(variadic2)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
@@ -109,6 +111,7 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   attributes["i32attr"] = i32attr
   // CHECK:   if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
   // CHECK:   if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
@@ -117,7 +120,7 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def i32attr(self):
@@ -150,6 +153,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_0))
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
   // CHECK:   if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
@@ -158,7 +162,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def in_(self):
@@ -181,10 +185,11 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -194,6 +199,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   results.append(i32)
   // CHECK:   results.append(_gen_res_1)
   // CHECK:   results.append(i64)
@@ -203,7 +209,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def f32(self):
@@ -230,12 +236,13 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
   // CHECK:   operands.extend(_get_op_results_or_values(variadic))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
@@ -258,12 +265,13 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   results.extend(variadic)
   // CHECK:   results.append(non_variadic)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def variadic(self):
@@ -285,11 +293,12 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(in_))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def in_(self):
@@ -351,6 +360,7 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
+  // CHECK:   regions = None
   // CHECK:   results.append(i64)
   // CHECK:   results.append(f64)
   // CHECK:   operands.append(_get_op_result_or_value(i32))
@@ -358,7 +368,7 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
-  // CHECK:     successors=_ods_successors, loc=loc, ip=ip))
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
   // CHECK: @builtins.property
   // CHECK: def i32(self):
@@ -379,6 +389,50 @@ def SimpleOp : TestOp<"simple"> {
   let results = (outs I64:$i64, F64:$f64);
 }
 
+// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
+def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
+  // CHECK:  def __init__(self, num_variadic, *, loc=None, ip=None):
+  // CHECK:    operands = []
+  // CHECK:    results = []
+  // CHECK:    attributes = {}
+  // CHECK:    regions = None
+  // CHECK:    _ods_successors = None
+  // CHECK:    regions = 2 + num_variadic
+  // CHECK:    super().__init__(self.build_generic(
+  // CHECK:      attributes=attributes, results=results, operands=operands,
+  // CHECK:      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+  let regions = (region AnyRegion:$region, AnyRegion, VariadicRegion<AnyRegion>:$variadic);
+
+  // CHECK:  @builtins.property
+  // CHECK:  def region(self):
+  // CHECK:    return self.regions[0]
+
+  // CHECK:  @builtins.property
+  // CHECK:  def variadic(self):
+  // CHECK:    return self.regions[2:]
+}
+
+// CHECK: class VariadicRegionOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
+def VariadicRegionOp : TestOp<"variadic_region"> {
+  // CHECK:  def __init__(self, num_variadic, *, loc=None, ip=None):
+  // CHECK:    operands = []
+  // CHECK:    results = []
+  // CHECK:    attributes = {}
+  // CHECK:    regions = None
+  // CHECK:    _ods_successors = None
+  // CHECK:    regions = 0 + num_variadic
+  // CHECK:    super().__init__(self.build_generic(
+  // CHECK:      attributes=attributes, results=results, operands=operands,
+  // CHECK:      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+  let regions = (region VariadicRegion<AnyRegion>:$Variadic);
+
+  // CHECK:  @builtins.property
+  // CHECK:  def Variadic(self):
+  // CHECK:    return self.regions[0:]
+}
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.with_successors"
@@ -390,3 +444,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
   let successors = (successor AnySuccessor:$successor,
                               VariadicSuccessor<AnySuccessor>:$successors);
 }
+

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 51993bbf6b051..039827ecf0839 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -252,6 +252,12 @@ constexpr const char *attributeDeleterTemplate = R"Py(
     del self.operation.attributes["{1}"]
 )Py";
 
+constexpr const char *regionAccessorTemplate = R"Py(
+  @builtins.property
+  def {0}(self):
+    return self.regions[{1}]
+)Py";
+
 static llvm::cl::OptionCategory
     clOpPythonBindingCat("Options for -gen-python-op-bindings");
 
@@ -482,10 +488,11 @@ constexpr const char *initTemplate = R"Py(
     operands = []
     results = []
     attributes = {{}
+    regions = None
     {1}
     super().__init__(self.build_generic(
       attributes=attributes, results=results, operands=operands,
-      successors=_ods_successors, loc=loc, ip=ip))
+      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 )Py";
 
 /// Template for appending a single element to the operand/result list.
@@ -697,6 +704,30 @@ populateBuilderLinesResult(const Operator &op,
   }
 }
 
+/// If the operation has a variadic region, adds a builder argument named
+/// `num_<region name, first letter lowercased>` for the number of variadic
+/// regions, and a line passing the total region count to the generic builder.
+static void
+populateBuilderRegions(const Operator &op,
+                       llvm::SmallVectorImpl<std::string> &builderArgs,
+                       llvm::SmallVectorImpl<std::string> &builderLines) {
+  if (op.hasNoVariadicRegions())
+    return;
+
+  // Operator construction currently guarantees that there is exactly one
+  // variadic region and that it is the last one.
+  const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
+  assert(op.getNumVariadicRegions() == 1 && region.isVariadic() &&
+         "expected the last region to be variadic");
+
+  std::string name =
+      ("num_" + region.name.take_front().lower() + region.name.drop_front())
+          .str();
+  builderArgs.push_back(name);
+  builderLines.push_back(
+      llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
+}
+
 /// Emits a default builder constructing an operation from the list of its
 /// result types, followed by a list of its operands.
 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
@@ -720,6 +751,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
       op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
       builderLines);
   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
+  populateBuilderRegions(op, builderArgs, builderLines);
 
   builderArgs.push_back("*");
   builderArgs.push_back("loc=None");
@@ -767,6 +799,21 @@ static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
                       op.hasNoVariadicRegions() ? "True" : "False");
 }
 
+/// Emits one Python property per named region (a slice for the variadic one).
+static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
+  for (unsigned idx = 0, e = op.getNumRegions(); idx != e; ++idx) {
+    const NamedRegion &namedRegion = op.getRegion(idx);
+    if (namedRegion.name.empty())
+      continue;
+    bool variadic = namedRegion.isVariadic();
+    assert((!variadic || idx == e - 1) &&
+           "expected only the last region to be variadic");
+    std::string indexExpr = std::to_string(idx) + (variadic ? ":" : "");
+    os << llvm::formatv(regionAccessorTemplate,
+                        sanitizeName(namedRegion.name), indexExpr);
+  }
+}
+
 /// Emits bindings for a specific Op to the given output stream.
 static void emitOpBindings(const Operator &op,
                            const AttributeClasses &attributeClasses,
@@ -787,6 +834,7 @@ static void emitOpBindings(const Operator &op,
   emitOperandAccessors(op, os);
   emitAttributeAccessors(op, attributeClasses, os);
   emitResultAccessors(op, os);
+  emitRegionAccessors(op, os);
 }
 
 /// Emits bindings for the dialect specified in the command line, including file


        


More information about the Mlir-commits mailing list