[Mlir-commits] [mlir] 989d2b5 - [mlir][tablegen] Generate default attr values in Python bindings

Matthias Springer llvmlistbot at llvm.org
Wed Jun 15 07:43:00 PDT 2022


Author: Matthias Springer
Date: 2022-06-15T16:40:27+02:00
New Revision: 989d2b518638616e3777d2c7fd3cca1481940937

URL: https://github.com/llvm/llvm-project/commit/989d2b518638616e3777d2c7fd3cca1481940937
DIFF: https://github.com/llvm/llvm-project/commit/989d2b518638616e3777d2c7fd3cca1481940937.diff

LOG: [mlir][tablegen] Generate default attr values in Python bindings

When an op attribute is declared with a default value (via DefaultValuedAttr), the default value is given as a string of C++ code. In the general case, such a default value cannot be translated to Python when generating the bindings. However, we can hard-code Python default values for frequently used C++ default values.

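This change adds a Python default value for empty ArrayAttrs: an ArrayAttr whose C++ default value is `{}` now defaults to `_ods_ir.ArrayAttr.get([])` in the generated builder. For any default value that cannot be translated, the generated builder asserts that the attribute is specified by the caller.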

Differential Revision: https://reviews.llvm.org/D127750

Added: 
    

Modified: 
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index f744ce501b106..2b73132a7431d 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -179,6 +179,27 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
 }
 
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
+def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
+  // CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   regions = None
+  // CHECK:   attributes["arr"] = arr if arr is not None else _ods_ir.ArrayAttr.get([])
+  // CHECK:   unsupported is not None, "attribute unsupported must be specified"
+  // CHECK:   _ods_successors = None
+  // CHECK:   super().__init__(self.build_generic(
+  // CHECK:     attributes=attributes, results=results, operands=operands,
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+
+  let arguments = (ins DefaultValuedAttr<I64ArrayAttr, "{}">:$arr,
+                       DefaultValuedAttr<I64ArrayAttr, "dummy_func()">:$unsupported);
+  let results = (outs);
+}
+
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
   // CHECK: def __init__(self, type, *, loc=None, ip=None):
@@ -544,4 +565,3 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
   let successors = (successor AnySuccessor:$successor,
                               VariadicSuccessor<AnySuccessor>:$successors);
 }
-

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 34db5bf46953b..e40d0ff8faf21 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/StringSet.h"
@@ -542,6 +543,21 @@ constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
 constexpr const char *initOptionalAttributeTemplate =
     R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
 
+/// Template for setting an attribute with a default value in the operation
+/// builder.
+///   {0} is the attribute name;
+///   {1} is the builder argument name;
+///   {2} is the default value.
+constexpr const char *initDefaultValuedAttributeTemplate =
+    R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py";
+
+/// Template for asserting that an attribute value was provided when calling a
+/// builder.
+///   {0} is the attribute name;
+///   {1} is the builder argument name.
+constexpr const char *assertAttributeValueSpecified =
+    R"Py(assert {1} is not None, "attribute {0} must be specified")Py";
+
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
       _ods_get_default_loc_context(loc)))Py";
@@ -647,6 +663,21 @@ static void populateBuilderArgsSuccessors(
   }
 }
 
+/// Generates Python code for the default value of the given attribute.
+static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) {
+  assert(attr.hasDefaultValue() && "expected attribute with default value");
+  StringRef storageType = attr.getStorageType().trim();
+  StringRef defaultValCpp = attr.getDefaultValue().trim();
+
+  // A list of commonly used attribute types and default values for which
+  // we can generate Python code. Extend as needed.
+  if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}"))
+    return std::string("_ods_ir.ArrayAttr.get([])");
+
+  // No match: Cannot generate Python code.
+  return failure();
+}
+
 /// Populates `builderLines` with additional lines that are required in the
 /// builder to set up operation attributes. `argNames` is expected to contain
 /// the names of builder arguments that correspond to op arguments, i.e. to the
@@ -669,6 +700,25 @@ populateBuilderLinesAttr(const Operator &op,
       continue;
     }
 
+    // Attributes with default value are handled specially.
+    if (attribute->attr.hasDefaultValue()) {
+      // In case we cannot generate Python code for the default value, the
+      // attribute must be specified by the user.
+      FailureOr<std::string> defaultValPy =
+          getAttributeDefaultValue(attribute->attr);
+      if (succeeded(defaultValPy)) {
+        builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate,
+                                             attribute->name, argNames[i],
+                                             *defaultValPy));
+      } else {
+        builderLines.push_back(llvm::formatv(assertAttributeValueSpecified,
+                                             attribute->name, argNames[i]));
+        builderLines.push_back(
+            llvm::formatv(initAttributeTemplate, attribute->name, argNames[i]));
+      }
+      continue;
+    }
+
     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
                                              ? initOptionalAttributeTemplate
                                              : initAttributeTemplate,


        


More information about the Mlir-commits mailing list