[Mlir-commits] [mlir] 9b79f50 - [mlir][tblgen][ods][python] Use keyword-only arguments for optional builder arguments in generated Python bindings

Stella Laurenzo llvmlistbot at llvm.org
Sun May 22 21:24:20 PDT 2022


Author: Jeremy Furtek
Date: 2022-05-21T21:18:53-07:00
New Revision: 9b79f50b59c671f3a7b0c7654cdd6d59ab6a6bbc

URL: https://github.com/llvm/llvm-project/commit/9b79f50b59c671f3a7b0c7654cdd6d59ab6a6bbc
DIFF: https://github.com/llvm/llvm-project/commit/9b79f50b59c671f3a7b0c7654cdd6d59ab6a6bbc.diff

LOG: [mlir][tblgen][ods][python] Use keyword-only arguments for optional builder arguments in generated Python bindings

This diff modifies `mlir-tblgen` to generate Python Operation class `__init__()`
functions that use Python keyword-only arguments.

Previously, all `__init__()` builder arguments (other than the existing
keyword-only `loc` and `ip`) were positional. Python code to
create MLIR Operations was required to provide values for ALL builder arguments,
including optional arguments (attributes and operands). Callers that did not
provide, for example, an optional attribute would be forced to provide `None`
as an argument for EACH optional attribute. Proposed changes in this diff use
`tblgen` record information (as provided by ODS) to generate keyword arguments
for:
- optional operands
- optional attributes (which include unit attributes)
- default-valued attributes

These `__init__()` function keyword arguments have default `None` values (i.e.
the argument form is `optionalAttr=None`), allowing callers to create Operations
more easily.

Note that because optional arguments become keyword-only arguments (they are
placed after the bare `*` argument), this diff will require ALL optional
operands and attributes to be provided using explicit keyword syntax. This may,
in the short term, break any out-of-tree Python code that provided values via
positional arguments. However, in the long term, it seems that requiring
keywords for optional arguments will be more robust to operation changes that
add arguments.

Tests were modified to reflect the updated Operation builder calling convention.

This diff partially addresses the requests made in the GitHub issue below.

https://github.com/llvm/llvm-project/issues/54932

Reviewed By: stellaraccident, mikeurbach

Differential Revision: https://reviews.llvm.org/D124717

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_func_ops_ext.py
    mlir/python/mlir/dialects/_ml_program_ops_ext.py
    mlir/python/mlir/dialects/_pdl_ops_ext.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/test/python/dialects/python_test.py
    mlir/test/python/dialects/vector.py
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
index 6fe3ff5302e26..79577463d9199 100644
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -58,7 +58,7 @@ def __init__(self,
     type = TypeAttr.get(type)
     sym_visibility = StringAttr.get(
         str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip)
+    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
     if body_builder:
       entry_block = self.add_entry_block()
       with InsertionPoint(entry_block):

diff  --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
index a3df7ff033607..8db82cf81c678 100644
--- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py
+++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
@@ -48,7 +48,7 @@ def __init__(self,
     type = TypeAttr.get(type)
     sym_visibility = StringAttr.get(
         str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip)
+    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
     if body_builder:
       entry_block = self.add_entry_block()
       with InsertionPoint(entry_block):

diff  --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
index fb5b519c7c022..bb63fe64dd035 100644
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -93,7 +93,7 @@ def __init__(self,
                ip=None):
     type = type if type is None else _get_value(type)
     result = pdl.AttributeType.get()
-    super().__init__(result, type, value, loc=loc, ip=ip)
+    super().__init__(result, type=type, value=value, loc=loc, ip=ip)
 
 
 class EraseOp:
@@ -118,7 +118,7 @@ def __init__(self,
                ip=None):
     type = type if type is None else _get_value(type)
     result = pdl.ValueType.get()
-    super().__init__(result, type, loc=loc, ip=ip)
+    super().__init__(result, type=type, loc=loc, ip=ip)
 
 
 class OperandsOp:
@@ -131,7 +131,7 @@ def __init__(self,
                ip=None):
     types = types if types is None else _get_value(types)
     result = pdl.RangeType.get(pdl.ValueType.get())
-    super().__init__(result, types, loc=loc, ip=ip)
+    super().__init__(result, type=types, loc=loc, ip=ip)
 
 
 class OperationOp:
@@ -155,7 +155,7 @@ def __init__(self,
     attributeNames = ArrayAttr.get(attributeNames)
     types = _get_values(types)
     result = pdl.OperationType.get()
-    super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip)
+    super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)
 
 
 class PatternOp:
@@ -170,7 +170,7 @@ def __init__(self,
     """Creates an PDL `pattern` operation."""
     name_attr = None if name is None else _get_str_attr(name)
     benefit_attr = _get_int_attr(16, benefit)
-    super().__init__(benefit_attr, name_attr, loc=loc, ip=ip)
+    super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip)
     self.regions[0].blocks.append()
 
   @property
@@ -192,7 +192,7 @@ def __init__(self,
     op = _get_value(op)
     with_op = with_op if with_op is None else _get_value(with_op)
     with_values = _get_values(with_values)
-    super().__init__(op, with_op, with_values, loc=loc, ip=ip)
+    super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
 
 
 class ResultOp:
@@ -222,7 +222,7 @@ def __init__(self,
                ip=None):
     parent = _get_value(parent)
     index = index if index is None else _get_int_attr(32, index)
-    super().__init__(result, parent, index, loc=loc, ip=ip)
+    super().__init__(result, parent, index=index, loc=loc, ip=ip)
 
 
 class RewriteOp:
@@ -238,7 +238,7 @@ def __init__(self,
     root = root if root is None else _get_value(root)
     name = name if name is None else _get_str_attr(name)
     args = _get_values(args)
-    super().__init__(root, name, args, loc=loc, ip=ip)
+    super().__init__(args, root=root,name=name, loc=loc, ip=ip)
 
   def add_body(self):
     """Add body (block) to the rewrite."""
@@ -261,7 +261,7 @@ def __init__(self,
                ip=None):
     type = type if type is None else _get_type_attr(type)
     result = pdl.TypeType.get()
-    super().__init__(result, type, loc=loc, ip=ip)
+    super().__init__(result, type=type, loc=loc, ip=ip)
 
 
 class TypesOp:
@@ -275,4 +275,4 @@ def __init__(self,
     types = _get_array_attr([_get_type_attr(ty) for ty in types])
     types = None if not types else types
     result = pdl.RangeType.get(pdl.TypeType.get())
-    super().__init__(result, types, loc=loc, ip=ip)
+    super().__init__(result, types=types, loc=loc, ip=ip)

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 59b5dec83c030..f744ce501b106 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -21,7 +21,7 @@ class TestOp<string mnemonic, list<Trait> traits = []> :
 // CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,]
 def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                                  [AttrSizedOperandSegments]> {
-  // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None):
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
@@ -110,7 +110,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def AttributedOp : TestOp<"attributed_op"> {
-  // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, *, loc=None, ip=None):
+  // CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None):
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
@@ -152,7 +152,7 @@ def AttributedOp : TestOp<"attributed_op"> {
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
-  // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, *, loc=None, ip=None):
+  // CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None):
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}
@@ -286,7 +286,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
-  // CHECK: def __init__(self, non_optional, optional, *, loc=None, ip=None):
+  // CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   attributes = {}

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index e7b1f44a3ad8e..c73fce23d3c49 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -28,13 +28,13 @@ def testAttributes():
     # CHECK-DAG: optional_i32 = 2 : i32
     # CHECK-DAG: unit
     # CHECK: }
-    op = test.AttributedOp(one, two, unit)
+    op = test.AttributedOp(one, optional_i32=two, unit=unit)
     print(f"{op}")
 
     # CHECK: "python_test.attributed_op"() {
     # CHECK: mandatory_i32 = 2 : i32
     # CHECK: }
-    op2 = test.AttributedOp(two, None, None)
+    op2 = test.AttributedOp(two)
     print(f"{op2}")
 
     #
@@ -218,11 +218,11 @@ def testOptionalOperandOp():
     module = Module.create()
     with InsertionPoint(module.body):
 
-      op1 = test.OptionalOperandOp(None)
+      op1 = test.OptionalOperandOp()
       # CHECK: op1.input is None: True
       print(f"op1.input is None: {op1.input is None}")
 
-      op2 = test.OptionalOperandOp(op1)
+      op2 = test.OptionalOperandOp(input=op1)
       # CHECK: op2.input is None: False
       print(f"op2.input is None: {op2.input is None}")
 

diff  --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
index c31579545e6e7..8f8d7f19191cf 100644
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -46,9 +46,9 @@ def testTransferReadOp():
     with InsertionPoint(f.add_entry_block()):
       A, zero, padding, mask = f.arguments
       vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding, mask, None)
+                            padding, mask=mask)
       vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding, None, None)
+                            padding)
       func.ReturnOp([])
 
   # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 16fccff973ca7..83d2acce3ba2c 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -620,6 +620,13 @@ populateBuilderArgs(const Operator &op,
     if (!op.getArg(i).is<NamedAttribute *>())
       operandNames.push_back(name);
   }
+}
+
+/// Populates `builderArgs` with the Python-compatible names of builder function
+/// successor arguments. Additionally, `successorArgNames` is also populated.
+static void populateBuilderArgsSuccessors(
+    const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
+    llvm::SmallVectorImpl<std::string> &successorArgNames) {
 
   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
     NamedSuccessor successor = op.getSuccessor(i);
@@ -857,6 +864,8 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
   populateBuilderArgsResults(op, builderArgs);
   size_t numResultArgs = builderArgs.size();
   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
+  size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
+  populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
 
   populateBuilderLinesOperand(op, operandArgNames, builderLines);
   populateBuilderLinesAttr(
@@ -868,10 +877,53 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
   populateBuilderRegions(op, builderArgs, builderLines);
 
-  builderArgs.push_back("*");
-  builderArgs.push_back("loc=None");
-  builderArgs.push_back("ip=None");
-  os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
+  // Layout of builderArgs vector elements:
+  // [ result_args  operand_attr_args successor_args regions ]
+
+  // Determine whether the argument corresponding to a given index into the
+  // builderArgs vector is a python keyword argument or not.
+  auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
+    // All result, successor, and region arguments are positional arguments.
+    if ((builderArgIndex < numResultArgs) ||
+        (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
+      return false;
+    // Keyword arguments:
+    // - optional named attributes (including unit attributes)
+    // - default-valued named attributes
+    // - optional operands
+    Argument a = op.getArg(builderArgIndex - numResultArgs);
+    if (auto *nattr = a.dyn_cast<NamedAttribute *>())
+      return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
+    else if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
+      return ntype->isOptional();
+    else
+      return false;
+  };
+
+  // StringRefs in functionArgs refer to strings allocated by builderArgs.
+  llvm::SmallVector<llvm::StringRef> functionArgs;
+
+  // Add positional arguments.
+  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
+    if (!isKeywordArgFn(i))
+      functionArgs.push_back(builderArgs[i]);
+  }
+
+  // Add a bare '*' to indicate that all following arguments must be keyword
+  // arguments.
+  functionArgs.push_back("*");
+
+  // Add a default 'None' value to each keyword arg string, and then add to the
+  // function args list.
+  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
+    if (isKeywordArgFn(i)) {
+      builderArgs[i].append("=None");
+      functionArgs.push_back(builderArgs[i]);
+    }
+  }
+  functionArgs.push_back("loc=None");
+  functionArgs.push_back("ip=None");
+  os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
                       llvm::join(builderLines, "\n    "));
 }
 


        


More information about the Mlir-commits mailing list