[Mlir-commits] [mlir] f9265de - [mlir] Generate Op builders for Python bindings

Alex Zinenko llvmlistbot at llvm.org
Thu Nov 12 02:29:32 PST 2020


Author: Alex Zinenko
Date: 2020-11-12T11:29:23+01:00
New Revision: f9265de8c634798b2ae8b4bdad7c2f5b7442115e

URL: https://github.com/llvm/llvm-project/commit/f9265de8c634798b2ae8b4bdad7c2f5b7442115e
DIFF: https://github.com/llvm/llvm-project/commit/f9265de8c634798b2ae8b4bdad7c2f5b7442115e.diff

LOG: [mlir] Generate Op builders for Python bindings

Add an ODS-backed generator of default builders. This currently does not
support operations with attribute arguments; for such operations, no
builder is generated. Attribute support will be introduced separately for
builders and accessors.

Default builders are always generated with the same number of result and
operand groups as the ODS specification, i.e. one group per operand
or result. Optional elements accept None but cannot be omitted. Variadic
groups accept iterable objects and cannot be replaced with a single
object.

For some operations, it is possible to infer the result type given the
traits, but most traits rely on inline pieces of C++ that we cannot
(yet) forward to Python bindings. Since the Ops where the inference is
possible (having the `SameOperandsAndResultType` trait or
`TypeMatchesWith` without a transform field) are a small minority, their
builders also require the result type, which keeps the builder syntax
consistent across all Ops.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D91190

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/test/Bindings/Python/dialects.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 652efa70fe06..b7b03f71ddd2 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -106,7 +106,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
                             SameOperandsAndResultType,
                             ElementwiseMappable])> {
 
-  let results = (outs AnyType);
+  let results = (outs AnyType:$result);
 
   let parser = [{
     return impl::parseOneResultSameOperandTypeOp(parser, result);

diff  --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
index 63ec61456a58..82e25b9e54fa 100644
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -63,7 +63,7 @@ def testUserDialectClass():
 run(testUserDialectClass)
 
 
-# XHECK-LABEL: TEST: testCustomOpView
+# CHECK-LABEL: TEST: testCustomOpView
 # This test uses the standard dialect AddFOp as an example of a user op.
 # TODO: Op creation and access is still quite verbose: simplify this test as
 # additional capabilities come online.
@@ -82,17 +82,17 @@ def createInput():
       # Create via dialects context collection.
       input1 = createInput()
       input2 = createInput()
-      op1 = ctx.dialects.std.AddFOp(input1, input2)
+      op1 = ctx.dialects.std.AddFOp(input1.type, input1, input2)
 
       # Create via an import
       from mlir.dialects.std import AddFOp
-      AddFOp(input1, op1.result)
+      AddFOp(input1.type, input1, op1.result)
 
-  # XHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
-  # XHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
-  # XHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
-  # XHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
+  # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
+  # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
+  # CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
+  # CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
   m.operation.print()
 
-# TODO: re-enable when constructs are generated again
-# run(testCustomOpView)
+
+run(testCustomOpView)

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 3d19379978d7..77af112a9b0e 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -18,6 +18,23 @@ class TestOp<string mnemonic, list<OpTrait> traits = []> :
 // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
 def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                                  [AttrSizedOperandSegments]> {
+  // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   operand_segment_sizes = array.array('L')
+  // CHECK:   operands += [*variadic1]
+  // CHECK:   operand_segment_sizes.append(len(variadic1))
+  // CHECK:   operands.append(non_variadic)
+  // CHECK:   operand_segment_sizes.append(1)
+  // CHECK:   if variadic2 is not None: operands.append(variadic2)
+  // CHECK:   operand_segment_sizes.append(0 if variadic2 is None else 1)
+  // CHECK:   attributes["operand_segment_sizes"] = _ir.DenseElementsAttr.get(operand_segment_sizes,
+  // CHECK:       context=Location.current.context if loc is None else loc.context)
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.attr_sized_operands", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
   // CHECK: @property
   // CHECK: def variadic1(self):
   // CHECK:   operand_range = _segmented_accessor(
@@ -47,6 +64,23 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
 def AttrSizedResultsOp : TestOp<"attr_sized_results",
                                [AttrSizedResultSegments]> {
+  // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   result_segment_sizes = array.array('L')
+  // CHECK:   if variadic1 is not None: results.append(variadic1)
+  // CHECK:   result_segment_sizes.append(0 if variadic1 is None else 1)
+  // CHECK:   results.append(non_variadic)
+  // CHECK:   result_segment_sizes.append(1) # non_variadic
+  // CHECK:   if variadic2 is not None: results.append(variadic2)
+  // CHECK:   result_segment_sizes.append(0 if variadic2 is None else 1)
+  // CHECK:   attributes["result_segment_sizes"] = _ir.DenseElementsAttr.get(result_segment_sizes,
+  // CHECK:       context=Location.current.context if loc is None else loc.context)
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.attr_sized_results", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
   // CHECK: @property
   // CHECK: def variadic1(self):
   // CHECK:   result_range = _segmented_accessor(
@@ -75,11 +109,32 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 // CHECK: class EmptyOp(_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.empty"
 def EmptyOp : TestOp<"empty">;
+  // CHECK: def __init__(self, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.empty", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
 
 // CHECK: @_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.missing_names"
 def MissingNamesOp : TestOp<"missing_names"> {
+  // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   results.append(i32)
+  // CHECK:   results.append(_gen_res_1)
+  // CHECK:   results.append(i64)
+  // CHECK:   operands.append(_gen_arg_0)
+  // CHECK:   operands.append(f32)
+  // CHECK:   operands.append(_gen_arg_2)
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.missing_names", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
   // CHECK: @property
   // CHECK: def f32(self):
   // CHECK:   return self.operation.operands[1]
@@ -99,6 +154,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
 // CHECK: class OneVariadicOperandOp(_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
 def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
+  // CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   operands.append(non_variadic)
+  // CHECK:   operands += [*variadic]
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.one_variadic_operand", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
   // CHECK: @property
   // CHECK: def non_variadic(self):
   // CHECK:   return self.operation.operands[0]
@@ -114,6 +179,16 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 // CHECK: class OneVariadicResultOp(_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
 def OneVariadicResultOp : TestOp<"one_variadic_result"> {
+  // CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   results += [*variadic]
+  // CHECK:   results.append(non_variadic)
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.one_variadic_result", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
   // CHECK: @property
   // CHECK: def variadic(self):
   // CHECK:   variadic_group_length = len(self.operation.results) - 2 + 1
@@ -130,6 +205,15 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
 // CHECK: class PythonKeywordOp(_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
 def PythonKeywordOp : TestOp<"python_keyword"> {
+  // CHECK: def __init__(self, in_, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   operands.append(in_)
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.python_keyword", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
   // CHECK: @property
   // CHECK: def in_(self):
   // CHECK:   return self.operation.operands[0]
@@ -186,6 +270,18 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 // CHECK: class SimpleOp(_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.simple"
 def SimpleOp : TestOp<"simple"> {
+  // CHECK: def __init__(self, i64, f64, i32, f32, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   results.append(i64)
+  // CHECK:   results.append(f64)
+  // CHECK:   operands.append(i32)
+  // CHECK:   operands.append(f32)
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.simple", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
   // CHECK: @property
   // CHECK: def i32(self):
   // CHECK:   return self.operation.operands[0]

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index f940aae38176..e32924451234 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -26,6 +26,7 @@ using namespace mlir::tblgen;
 constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
+import array
 from . import _cext
 from . import _segmented_accessor, _equally_sized_accessor
 _ir = _cext.ir
@@ -172,6 +173,12 @@ static std::string sanitizeName(StringRef name) {
   return name.str();
 }
 
+static std::string attrSizedTraitForKind(const char *kind) {
+  return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
+                       llvm::StringRef(kind).take_front().upper(),
+                       llvm::StringRef(kind).drop_front());
+}
+
 /// Emits accessors to "elements" of an Op definition. Currently, the supported
 /// elements are operands and results, indicated by `kind`, which must be either
 /// `operand` or `result` and is used verbatim in the emitted code.
@@ -190,10 +197,7 @@ static void emitElementAccessors(
       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
                     llvm::StringRef(kind).take_front().upper(),
                     llvm::StringRef(kind).drop_front());
-  std::string attrSizedTrait =
-      llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
-                    llvm::StringRef(kind).take_front().upper(),
-                    llvm::StringRef(kind).drop_front());
+  std::string attrSizedTrait = attrSizedTraitForKind(kind);
 
   unsigned numVariadic = getNumVariadic(op);
 
@@ -271,20 +275,23 @@ static void emitElementAccessors(
   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
 }
 
+/// Free function helpers accessing Operator components.
+static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
+static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
+  return op.getOperand(i);
+}
+static int getNumResults(const Operator &op) { return op.getNumResults(); }
+static const NamedTypeConstraint &getResult(const Operator &op, int i) {
+  return op.getResult(i);
+}
+
 /// Emits accessor to Op operands.
 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
   auto getNumVariadic = [](const Operator &oper) {
     return oper.getNumVariableLengthOperands();
   };
-  auto getNumElements = [](const Operator &oper) {
-    return oper.getNumOperands();
-  };
-  auto getElement = [](const Operator &oper,
-                       int i) -> const NamedTypeConstraint & {
-    return oper.getOperand(i);
-  };
-  emitElementAccessors(op, os, "operand", getNumVariadic, getNumElements,
-                       getElement);
+  emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
+                       getOperand);
 }
 
 /// Emits access or Op results.
@@ -292,21 +299,152 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) {
   auto getNumVariadic = [](const Operator &oper) {
     return oper.getNumVariableLengthResults();
   };
-  auto getNumElements = [](const Operator &oper) {
-    return oper.getNumResults();
-  };
-  auto getElement = [](const Operator &oper,
-                       int i) -> const NamedTypeConstraint & {
-    return oper.getResult(i);
-  };
-  emitElementAccessors(op, os, "result", getNumVariadic, getNumElements,
-                       getElement);
+  emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
+                       getResult);
+}
+
+/// Template for the default auto-generated builder.
+///   {0} is the operation name;
+///   {1} is a comma-separated list of builder arguments, including the trailing
+///       `loc` and `ip`;
+///   {2} is the code populating `operands`, `results` and `attributes` fields.
+constexpr const char *initTemplate = R"Py(
+  def __init__(self, {1}):
+    operands = []
+    results = []
+    attributes = {{}
+    {2}
+    super().__init__(_ir.Operation.create(
+      "{0}", attributes=attributes, operands=operands, results=results,
+      loc=loc, ip=ip))
+)Py";
+
+/// Template for appending a single element to the operand/result list.
+///   {0} is either 'operand' or 'result';
+///   {1} is the field name.
+constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
+
+/// Template for appending an optional element to the operand/result list.
+///   {0} is either 'operand' or 'result';
+///   {1} is the field name.
+constexpr const char *optionalAppendTemplate =
+    "if {1} is not None: {0}s.append({1})";
+
+/// Template for appending a variadic element to the operand/result list.
+///   {0} is either 'operand' or 'result';
+///   {1} is the field name.
+constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";
+
+/// Template for setting up the segment sizes buffer.
+constexpr const char *segmentDeclarationTemplate =
+    "{0}_segment_sizes = array.array('L')";
+
+/// Template for attaching segment sizes to the attribute list.
+constexpr const char *segmentAttributeTemplate =
+    R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes,
+      context=Location.current.context if loc is None else loc.context))Py";
+
+/// Template for appending the unit size to the segment sizes.
+///   {0} is either 'operand' or 'result';
+///   {1} is the field name.
+constexpr const char *singleElementSegmentTemplate =
+    "{0}_segment_sizes.append(1) # {1}";
+
+/// Template for appending 0/1 for an optional element to the segment sizes.
+///   {0} is either 'operand' or 'result';
+///   {1} is the field name.
+constexpr const char *optionalSegmentTemplate =
+    "{0}_segment_sizes.append(0 if {1} is None else 1)";
+
+/// Template for appending the length of a variadic group to the segment sizes.
+///   {0} is either 'operand' or 'result';
+///   {1} is the field name.
+constexpr const char *variadicSegmentTemplate =
+    "{0}_segment_sizes.append(len({1}))";
+
+/// Populates `builderArgs` with the list of `__init__` arguments that
+/// correspond to either operands or results of `op`, and `builderLines` with
+/// additional lines that are required in the builder. `kind` must be either
+/// "operand" or "result". `unnamedTemplate` is used to generate names for
+/// operands or results that don't have the name in ODS.
+static void populateBuilderLines(
+    const Operator &op, const char *kind, const char *unnamedTemplate,
+    llvm::SmallVectorImpl<std::string> &builderArgs,
+    llvm::SmallVectorImpl<std::string> &builderLines,
+    llvm::function_ref<int(const Operator &)> getNumElements,
+    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
+        getElement) {
+  // The segment sizes buffer only has to be populated if there attr-sized
+  // segments trait is present.
+  bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
+  if (includeSegments)
+    builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind));
+
+  // For each element, find or generate a name.
+  for (int i = 0, e = getNumElements(op); i < e; ++i) {
+    const NamedTypeConstraint &element = getElement(op, i);
+    std::string name = element.name.str();
+    if (name.empty())
+      name = llvm::formatv(unnamedTemplate, i).str();
+    name = sanitizeName(name);
+    builderArgs.push_back(name);
+
+    // Choose the formatting string based on the element kind.
+    llvm::StringRef formatString, segmentFormatString;
+    if (!element.isVariableLength()) {
+      formatString = singleElementAppendTemplate;
+      segmentFormatString = singleElementSegmentTemplate;
+    } else if (element.isOptional()) {
+      formatString = optionalAppendTemplate;
+      segmentFormatString = optionalSegmentTemplate;
+    } else {
+      assert(element.isVariadic() && "unhandled element group type");
+      formatString = variadicAppendTemplate;
+      segmentFormatString = variadicSegmentTemplate;
+    }
+
+    // Add the lines.
+    builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
+    if (includeSegments)
+      builderLines.push_back(
+          llvm::formatv(segmentFormatString.data(), kind, name));
+  }
+
+  if (includeSegments)
+    builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind));
+}
+
+/// Emits a default builder constructing an operation from the list of its
+/// result types, followed by a list of its operands.
+static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
+  // TODO: support attribute types.
+  if (op.getNumNativeAttributes() != 0)
+    return;
+
+  // If we are asked to skip default builders, comply.
+  if (op.skipDefaultBuilders())
+    return;
+
+  llvm::SmallVector<std::string, 8> builderArgs;
+  llvm::SmallVector<std::string, 8> builderLines;
+  builderArgs.reserve(op.getNumOperands() + op.getNumResults());
+  populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines,
+                       getNumResults, getResult);
+  populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines,
+                       getNumOperands, getOperand);
+
+  builderArgs.push_back("loc=None");
+  builderArgs.push_back("ip=None");
+  os << llvm::formatv(initTemplate, op.getOperationName(),
+                      llvm::join(builderArgs, ", "),
+                      llvm::join(builderLines, "\n    "));
 }
 
 /// Emits bindings for a specific Op to the given output stream.
 static void emitOpBindings(const Operator &op, raw_ostream &os) {
   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
                       op.getOperationName());
+  emitDefaultOpBuilder(op, os);
   emitOperandAccessors(op, os);
   emitResultAccessors(op, os);
 }


        


More information about the Mlir-commits mailing list