[Mlir-commits] [mlir] c5a6712 - [mlir] Add basic support for attributes in ODS-generated Python bindings

Alex Zinenko llvmlistbot at llvm.org
Tue Nov 17 02:47:46 PST 2020


Author: Alex Zinenko
Date: 2020-11-17T11:47:37+01:00
New Revision: c5a6712f8cee1bfbb5d730531943558645853c0b

URL: https://github.com/llvm/llvm-project/commit/c5a6712f8cee1bfbb5d730531943558645853c0b
DIFF: https://github.com/llvm/llvm-project/commit/c5a6712f8cee1bfbb5d730531943558645853c0b.diff

LOG: [mlir] Add basic support for attributes in ODS-generated Python bindings

In ODS, attributes of an operation can be provided as a part of the "arguments"
field, together with operands. Such attributes are accepted by the op builder
and have accessors generated.

Implement similar functionality for ODS-generated op-specific Python bindings:
the `__init__` method now accepts attributes together with operands, in the same
order as in the ODS `arguments` field; the instance properties are introduced
to OpView classes to access the attributes.

This initial implementation accepts and returns instances of the corresponding
attribute class, and not the underlying values since the mapping scheme of the
value types between C++, C and Python is not yet clear. Default-valued
attributes are not supported as that would require Python to be able to parse
C++ literals.

Since attributes in ODS are tightly related to the actual C++ type system,
provide a separate Tablegen file with the mapping between ODS storage type for
attributes (typically, the underlying C++ attribute class), and the
corresponding class name. So far, this might look unnecessary since all names
match exactly, but this is not necessarily the case for non-standard,
out-of-tree attributes, which may also be placed in non-default namespaces or
Python modules. This also allows out-of-tree users to generate Python bindings
without having to modify the bindings generator itself. Storage type was
preferred over the Tablegen "def" of the attribute class because ODS
essentially encodes attribute _constraints_ rather than classes, e.g. there may
be many Tablegen "def"s in the ODS that correspond to the same attribute type
with additional constraints.

The presence of the explicit mapping requires a change in the .td file
structure: instead of just calling the bindings generator directly on the main
ODS file of the dialect, it becomes necessary to create a new file that
includes the main ODS file of the dialect and provides the mapping for
attribute types. Arguably, this approach offers better separability of the
Python bindings in the build system as the main dialect no longer needs to know
that it is being processed by the bindings generator.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D91542

Added: 
    mlir/lib/Bindings/Python/Attributes.td
    mlir/lib/Bindings/Python/StandardOps.td

Modified: 
    mlir/CMakeLists.txt
    mlir/cmake/modules/AddMLIRPythonExtension.cmake
    mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 8c8d59280968..09ab67ff73dc 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -100,12 +100,6 @@ include_directories( ${MLIR_INCLUDE_DIR})
 # from another directory like tools
 add_subdirectory(tools/mlir-tblgen)
 
-# Create an anchor target that will depend on dialect-specific op bindings.
-if (MLIR_BINDINGS_PYTHON_ENABLED)
-  add_custom_target(MLIRBindingsPythonIncGen)
-  include(AddMLIRPythonExtension)
-endif()
-
 add_subdirectory(include/mlir)
 add_subdirectory(lib)
 # C API needs all dialects for registration, but should be built before tests.

diff  --git a/mlir/cmake/modules/AddMLIRPythonExtension.cmake b/mlir/cmake/modules/AddMLIRPythonExtension.cmake
index 43ad869a400b..eaba5214b5a8 100644
--- a/mlir/cmake/modules/AddMLIRPythonExtension.cmake
+++ b/mlir/cmake/modules/AddMLIRPythonExtension.cmake
@@ -132,16 +132,10 @@ function(add_mlir_python_extension libname extname)
 
 endfunction()
 
-function(add_mlir_dialect_python_bindings filename dialectname)
+function(add_mlir_dialect_python_bindings tblgen_target filename dialectname)
   set(LLVM_TARGET_DEFINITIONS ${filename})
   mlir_tablegen("${dialectname}.py" -gen-python-op-bindings
                 -bind-dialect=${dialectname})
-  if (${ARGC} GREATER 2)
-    set(suffix ${ARGV2})
-  else()
-    get_filename_component(suffix ${filename} NAME_WE)
-  endif()
-  set(tblgen_target "MLIRBindingsPython${suffix}")
   add_public_tablegen_target(${tblgen_target})
 
   add_custom_command(
@@ -150,6 +144,5 @@ function(add_mlir_dialect_python_bindings filename dialectname)
     COMMAND "${CMAKE_COMMAND}" -E copy_if_different
       "${CMAKE_CURRENT_BINARY_DIR}/${dialectname}.py"
       "${PROJECT_BINARY_DIR}/python/mlir/dialects/${dialectname}.py")
-  add_dependencies(MLIRBindingsPythonIncGen ${tblgen_target})
 endfunction()
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
index ee3e3cfdd9f2..b9178c5a0db3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
@@ -7,7 +7,3 @@ mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRStandardOpsIncGen)
 
 add_mlir_doc(Ops -gen-op-doc StandardOps Dialects/)
-
-if (MLIR_BINDINGS_PYTHON_ENABLED)
-  add_mlir_dialect_python_bindings(Ops.td std StandardOps)
-endif()

diff  --git a/mlir/lib/Bindings/Python/Attributes.td b/mlir/lib/Bindings/Python/Attributes.td
new file mode 100644
index 000000000000..0ed155035a99
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Attributes.td
@@ -0,0 +1,34 @@
+//===-- Attributes.td - Attribute mapping for Python -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This defines the mapping between MLIR ODS attributes and the corresponding
+// Python binding classes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_ATTRIBUTES
+#define PYTHON_BINDINGS_ATTRIBUTES
+
+// A mapping between the attribute storage type and the corresponding Python
+// type. There is not necessarily a 1-1 match for non-standard attributes.
+class PythonAttr<string c, string p> {
+  string cppStorageType = c;
+  string pythonType = p;
+}
+
+// Mappings between supported standard attributes and Python types.
+def : PythonAttr<"::mlir::Attribute", "_ir.Attribute">;
+def : PythonAttr<"::mlir::BoolAttr", "_ir.BoolAttr">;
+def : PythonAttr<"::mlir::IntegerAttr", "_ir.IntegerAttr">;
+def : PythonAttr<"::mlir::FloatAttr", "_ir.FloatAttr">;
+def : PythonAttr<"::mlir::StringAttr", "_ir.StringAttr">;
+def : PythonAttr<"::mlir::DenseElementsAttr", "_ir.DenseElementsAttr">;
+def : PythonAttr<"::mlir::DenseIntElementsAttr", "_ir.DenseIntElementsAttr">;
+def : PythonAttr<"::mlir::DenseFPElementsAttr", "_ir.DenseFPElementsAttr">;
+
+#endif

diff  --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 596bff23093e..0f51ce54ed09 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -1,5 +1,15 @@
 include(AddMLIRPythonExtension)
 add_custom_target(MLIRBindingsPythonExtension)
+
+################################################################################
+# Generate dialect-specific bindings.
+################################################################################
+
+add_mlir_dialect_python_bindings(MLIRBindingsPythonStandardOps
+  StandardOps.td
+  std)
+add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonStandardOps)
+
 ################################################################################
 # Copy python source tree.
 ################################################################################
@@ -19,8 +29,6 @@ add_custom_target(MLIRBindingsPythonSources ALL
 )
 add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonSources)
 
-add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonIncGen)
-
 foreach(PY_SRC_FILE ${PY_SRC_FILES})
   set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
   add_custom_command(

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 24b3da2b821f..152f067ea636 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1310,8 +1310,14 @@ class PyOpAttributeMap {
     return mlirOperationGetNumAttributes(operation->get());
   }
 
+  bool dunderContains(const std::string &name) {
+    return !mlirAttributeIsNull(
+        mlirOperationGetAttributeByName(operation->get(), name.c_str()));
+  }
+
   static void bind(py::module &m) {
     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__contains__", &PyOpAttributeMap::dunderContains)
         .def("__len__", &PyOpAttributeMap::dunderLen)
         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
@@ -1747,6 +1753,24 @@ class PyDenseFPElementsAttribute
   }
 };
 
+/// Unit Attribute subclass. Unit attributes don't have values.
+class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+  static constexpr const char *pyClassName = "UnitAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return PyUnitAttribute(context->getRef(),
+                                 mlirUnitAttrGet(context->get()));
+        },
+        py::arg("context") = py::none(), "Create a Unit attribute.");
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -2852,6 +2876,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyUnitAttribute::bind(m);
 
   //----------------------------------------------------------------------------
   // Mapping of PyType.

diff  --git a/mlir/lib/Bindings/Python/StandardOps.td b/mlir/lib/Bindings/Python/StandardOps.td
new file mode 100644
index 000000000000..1bc7b09a4719
--- /dev/null
+++ b/mlir/lib/Bindings/Python/StandardOps.td
@@ -0,0 +1,20 @@
+//===-- StandardOps.td - Entry point for StandardOps bind --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the main file from which the Python bindings for the Standard
+// dialect are generated.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_STANDARD_OPS
+#define PYTHON_BINDINGS_STANDARD_OPS
+
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "Attributes.td"
+
+#endif

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 605804baf468..b59ffaf8d092 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -147,7 +147,7 @@ Operator::arg_range Operator::getArgs() const {
 
 StringRef Operator::getArgName(int index) const {
   DagInit *argumentValues = def.getValueAsDag("arguments");
-  return argumentValues->getArgName(index)->getValue();
+  return argumentValues->getArgNameStr(index);
 }
 
 auto Operator::getArgDecorators(int index) const -> var_decorator_range {

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 77af112a9b0e..04d798a42641 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -1,6 +1,7 @@
-// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s
+// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include -I %S/../../lib/Bindings/Python %s | FileCheck %s
 
 include "mlir/IR/OpBase.td"
+include "Attributes.td"
 
 // CHECK: @_cext.register_dialect
 // CHECK: class _Dialect(_ir.Dialect):
@@ -105,6 +106,75 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
                  Optional<AnyType>:$variadic2);
 }
 
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class AttributedOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
+def AttributedOp : TestOp<"attributed_op"> {
+  // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   attributes["i32attr"] = i32attr
+  // CHECK:   if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
+  // CHECK:   if bool(unitAttr): attributes["unitAttr"] = _ir.UnitAttr.get(
+  // CHECK:     _ir.Location.current.context if loc is None else loc.context)
+  // CHECK:   attributes["in"] = in_
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.attributed_op", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
+  // CHECK: @property
+  // CHECK: def i32attr(self):
+  // CHECK:   return _ir.IntegerAttr(self.operation.attributes["i32attr"])
+
+  // CHECK: @property
+  // CHECK: def optionalF32Attr(self):
+  // CHECK:   if "optionalF32Attr" not in self.operation.attributes:
+  // CHECK:     return None
+  // CHECK:   return _ir.FloatAttr(self.operation.attributes["optionalF32Attr"])
+
+  // CHECK: @property
+  // CHECK: def unitAttr(self):
+  // CHECK:   return "unitAttr" in self.operation.attributes
+
+  // CHECK: @property
+  // CHECK: def in_(self):
+  // CHECK:   return _ir.IntegerAttr(self.operation.attributes["in"])
+  let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
+                   UnitAttr:$unitAttr, I32Attr:$in);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class AttributedOpWithOperands(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
+def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
+  // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   operands.append(_gen_arg_0)
+  // CHECK:   operands.append(_gen_arg_2)
+  // CHECK:   if bool(in_): attributes["in"] = _ir.UnitAttr.get(
+  // CHECK:     _ir.Location.current.context if loc is None else loc.context)
+  // CHECK:   if is_ is not None: attributes["is"] = is_
+  // CHECK:   super().__init__(_ir.Operation.create(
+  // CHECK:     "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results,
+  // CHECK:     loc=loc, ip=ip))
+
+  // CHECK: @property
+  // CHECK: def in_(self):
+  // CHECK:   return "in" in self.operation.attributes
+
+  // CHECK: @property
+  // CHECK: def is_(self):
+  // CHECK:   if "is" not in self.operation.attributes:
+  // CHECK:     return None
+  // CHECK:   return _ir.FloatAttr(self.operation.attributes["is"])
+  let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
+}
+
+
 // CHECK: @_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.empty"

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index e32924451234..2a3ce5500133 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -145,6 +145,39 @@ constexpr const char *opVariadicSegmentTemplate = R"Py(
 constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
     R"Py([0] if len({0}_range) > 0 else None)Py";
 
+/// Template for an operation attribute getter:
+///   {0} is the name of the attribute sanitized for Python;
+///   {1} is the Python type of the attribute;
+///   {2} is the original name of the attribute.
+constexpr const char *attributeGetterTemplate = R"Py(
+  @property
+  def {0}(self):
+    return {1}(self.operation.attributes["{2}"])
+)Py";
+
+/// Template for an optional operation attribute getter:
+///   {0} is the name of the attribute sanitized for Python;
+///   {1} is the Python type of the attribute;
+///   {2} is the original name of the attribute.
+constexpr const char *optionalAttributeGetterTemplate = R"Py(
+  @property
+  def {0}(self):
+    if "{2}" not in self.operation.attributes:
+      return None
+    return {1}(self.operation.attributes["{2}"])
+)Py";
+
+/// Template for accessing a unit operation attribute; returns True if the
+/// unit attribute is present, False otherwise (unit attributes have meaning
+/// by mere presence):
+///   {0} is the name of the attribute sanitized for Python;
+///   {1} is the original name of the attribute.
+constexpr const char *unitAttributeGetterTemplate = R"Py(
+  @property
+  def {0}(self):
+    return "{1}" in self.operation.attributes
+)Py";
+
 static llvm::cl::OptionCategory
     clOpPythonBindingCat("Options for -gen-python-op-bindings");
 
@@ -153,6 +186,8 @@ static llvm::cl::opt<std::string>
                   llvm::cl::desc("The dialect to run the generator for"),
                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
 
+using AttributeClasses = DenseMap<StringRef, StringRef>;
+
 /// Checks whether `str` is a Python keyword.
 static bool isPythonKeyword(StringRef str) {
   static llvm::StringSet<> keywords(
@@ -285,7 +320,7 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
   return op.getResult(i);
 }
 
-/// Emits accessor to Op operands.
+/// Emits accessors to Op operands.
 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
   auto getNumVariadic = [](const Operator &oper) {
     return oper.getNumVariableLengthOperands();
@@ -294,7 +329,7 @@ static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
                        getOperand);
 }
 
-/// Emits access or Op results.
+/// Emits accessors to Op results.
 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
   auto getNumVariadic = [](const Operator &oper) {
     return oper.getNumVariableLengthResults();
@@ -303,6 +338,39 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) {
                        getResult);
 }
 
+/// Emits Python property accessors for named, non-derived Op attributes.
+static void emitAttributeAccessors(const Operator &op,
+                                   const AttributeClasses &attributeClasses,
+                                   raw_ostream &os) {
+  for (const auto &namedAttr : op.getAttributes()) {
+    // Skip "derived" attributes because they are just C++ functions that we
+    // don't currently expose.
+    if (namedAttr.attr.isDerivedAttr())
+      continue;
+    // Unnamed attributes have no name to expose as a Python property.
+    if (namedAttr.name.empty())
+      continue;
+
+    // Unit attributes are handled specially.
+    if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
+      os << llvm::formatv(unitAttributeGetterTemplate,
+                          sanitizeName(namedAttr.name), namedAttr.name);
+      continue;
+    }
+
+    // Other attributes need a mapping keyed by their (trimmed) storage type.
+    if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
+      continue;
+
+    os << llvm::formatv(
+        namedAttr.attr.isOptional() ? optionalAttributeGetterTemplate
+                                    : attributeGetterTemplate,
+        sanitizeName(namedAttr.name),
+        attributeClasses.lookup(namedAttr.attr.getStorageType().trim()),
+        namedAttr.name);
+  }
+}
+
 /// Template for the default auto-generated builder.
 ///   {0} is the operation name;
 ///   {1} is a comma-separated list of builder arguments, including the trailing
@@ -362,14 +430,82 @@ constexpr const char *optionalSegmentTemplate =
 constexpr const char *variadicSegmentTemplate =
     "{0}_segment_sizes.append(len({1}))";
 
-/// Populates `builderArgs` with the list of `__init__` arguments that
-/// correspond to either operands or results of `op`, and `builderLines` with
-/// additional lines that are required in the builder. `kind` must be either
-/// "operand" or "result". `unnamedTemplate` is used to generate names for
-/// operands or results that don't have the name in ODS.
+/// Template for setting an attribute in the operation builder.
+///   {0} is the attribute name;
+///   {1} is the builder argument name.
+constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
+
+/// Template for setting an optional attribute in the operation builder.
+///   {0} is the attribute name;
+///   {1} is the builder argument name.
+constexpr const char *initOptionalAttributeTemplate =
+    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
+
+constexpr const char *initUnitAttributeTemplate =
+    R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get(
+      _ir.Location.current.context if loc is None else loc.context))Py";
+
+/// Populates `builderArgs` with the Python-compatible names of builder function
+/// arguments, first the results, then the intermixed attributes and operands in
+/// the same order as they appear in the `arguments` field of the op definition.
+/// Additionally, `operandNames` is populated with names of operands in their
+/// order of appearance.
+static void
+populateBuilderArgs(const Operator &op,
+                    llvm::SmallVectorImpl<std::string> &builderArgs,
+                    llvm::SmallVectorImpl<std::string> &operandNames) {
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    std::string name = op.getResultName(i).str();
+    if (name.empty())
+      name = llvm::formatv("_gen_res_{0}", i);
+    name = sanitizeName(name);
+    builderArgs.push_back(name);
+  }
+  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
+    std::string name = op.getArgName(i).str();
+    if (name.empty())
+      name = llvm::formatv("_gen_arg_{0}", i);
+    name = sanitizeName(name);
+    builderArgs.push_back(name);
+    if (!op.getArg(i).is<NamedAttribute *>())
+      operandNames.push_back(name);
+  }
+}
+
+/// Populates `builderLines` with additional lines that are required in the
+/// builder to set up operation attributes. `argNames` is expected to contain
+/// the names of builder arguments that correspond to op arguments, i.e. to the
+/// operands and attributes in the same order as they appear in the `arguments`
+/// field.
+static void
+populateBuilderLinesAttr(const Operator &op,
+                         llvm::ArrayRef<std::string> argNames,
+                         llvm::SmallVectorImpl<std::string> &builderLines) {
+  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
+    Argument arg = op.getArg(i);
+    auto *attribute = arg.dyn_cast<NamedAttribute *>();
+    if (!attribute)
+      continue;
+
+    // Unit attributes are handled specially.
+    if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
+      builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
+                                           attribute->name, argNames[i]));
+      continue;
+    }
+
+    builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
+                                             ? initOptionalAttributeTemplate
+                                             : initAttributeTemplate,
+                                         attribute->name, argNames[i]));
+  }
+}
+
+/// Populates `builderLines` with additional lines that are required in the
+/// builder. `kind` must be either "operand" or "result". `names` contains the
+/// names of init arguments that correspond to the elements.
 static void populateBuilderLines(
-    const Operator &op, const char *kind, const char *unnamedTemplate,
-    llvm::SmallVectorImpl<std::string> &builderArgs,
+    const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
     llvm::SmallVectorImpl<std::string> &builderLines,
     llvm::function_ref<int(const Operator &)> getNumElements,
     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
@@ -383,11 +519,7 @@ static void populateBuilderLines(
   // For each element, find or generate a name.
   for (int i = 0, e = getNumElements(op); i < e; ++i) {
     const NamedTypeConstraint &element = getElement(op, i);
-    std::string name = element.name.str();
-    if (name.empty())
-      name = llvm::formatv(unnamedTemplate, i).str();
-    name = sanitizeName(name);
-    builderArgs.push_back(name);
+    std::string name = names[i];
 
     // Choose the formatting string based on the element kind.
     llvm::StringRef formatString, segmentFormatString;
@@ -417,21 +549,25 @@ static void populateBuilderLines(
 /// Emits a default builder constructing an operation from the list of its
 /// result types, followed by a list of its operands.
 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
-  // TODO: support attribute types.
-  if (op.getNumNativeAttributes() != 0)
-    return;
-
   // If we are asked to skip default builders, comply.
   if (op.skipDefaultBuilders())
     return;
 
   llvm::SmallVector<std::string, 8> builderArgs;
   llvm::SmallVector<std::string, 8> builderLines;
-  builderArgs.reserve(op.getNumOperands() + op.getNumResults());
-  populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines,
-                       getNumResults, getResult);
-  populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines,
+  llvm::SmallVector<std::string, 4> operandArgNames;
+  builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
+                      op.getNumNativeAttributes());
+  populateBuilderArgs(op, builderArgs, operandArgNames);
+  populateBuilderLines(
+      op, "result",
+      llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
+      builderLines, getNumResults, getResult);
+  populateBuilderLines(op, "operand", operandArgNames, builderLines,
                        getNumOperands, getOperand);
+  populateBuilderLinesAttr(
+      op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
+      builderLines);
 
   builderArgs.push_back("loc=None");
   builderArgs.push_back("ip=None");
@@ -440,12 +576,24 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
                       llvm::join(builderLines, "\n    "));
 }
 
+static void constructAttributeMapping(const llvm::RecordKeeper &records,
+                                      AttributeClasses &attributeClasses) {
+  for (const llvm::Record *rec :
+       records.getAllDerivedDefinitions("PythonAttr")) {
+    attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
+                                 rec->getValueAsString("pythonType").trim());
+  }
+}
+
 /// Emits bindings for a specific Op to the given output stream.
-static void emitOpBindings(const Operator &op, raw_ostream &os) {
+static void emitOpBindings(const Operator &op,
+                           const AttributeClasses &attributeClasses,
+                           raw_ostream &os) {
   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
                       op.getOperationName());
   emitDefaultOpBuilder(op, os);
   emitOperandAccessors(op, os);
+  emitAttributeAccessors(op, attributeClasses, os);
   emitResultAccessors(op, os);
 }
 
@@ -456,12 +604,15 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
   if (clDialectName.empty())
     llvm::PrintFatalError("dialect name not provided");
 
+  AttributeClasses attributeClasses;
+  constructAttributeMapping(records, attributeClasses);
+
   os << fileHeader;
   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
     Operator op(rec);
     if (op.getDialectName() == clDialectName.getValue())
-      emitOpBindings(op, os);
+      emitOpBindings(op, attributeClasses, os);
   }
   return false;
 }


        


More information about the Mlir-commits mailing list