[Mlir-commits] [mlir] 3137c29 - Add initial python bindings for attributes.

Stella Laurenzo llvmlistbot at llvm.org
Sun Aug 23 22:18:35 PDT 2020


Author: Stella Laurenzo
Date: 2020-08-23T22:16:23-07:00
New Revision: 3137c299269dd758c4c1630dc0c4621a1137eb7c

URL: https://github.com/llvm/llvm-project/commit/3137c299269dd758c4c1630dc0c4621a1137eb7c
DIFF: https://github.com/llvm/llvm-project/commit/3137c299269dd758c4c1630dc0c4621a1137eb7c.diff

LOG: Add initial python bindings for attributes.

* Generic mlir.ir.Attribute class.
* First standard attribute (mlir.ir.StringAttr), following the same pattern as generic vs standard types.
* NamedAttribute class.

Differential Revision: https://reviews.llvm.org/D86250

Added: 
    mlir/test/Bindings/Python/ir_attributes.py

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/PybindUtils.cpp
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 9293a40ebbab..225b53166306 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -336,6 +336,9 @@ void mlirTypeDump(MlirType type);
 /** Parses an attribute. The attribute is owned by the context. */
 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
 
+/** Returns nonzero if the attribute is null (its `ptr` is null). */
+static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
+
 /** Checks if two attributes are equal. */
 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 188fdf39ff14..ae48e33d3530 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -9,6 +9,7 @@
 #include "IRModules.h"
 #include "PybindUtils.h"
 
+#include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"
 
 namespace py = pybind11;
@@ -76,8 +77,52 @@ struct PyPrintAccumulator {
   }
 };
 
+/// Accumulates into a python string from a method that is expected to make
+/// exactly one call to the callback (asserts in debug builds on violation;
+/// see the asserts in the callback and in takeValue).
+struct PySinglePartStringAccumulator {
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      PySinglePartStringAccumulator *accum =
+          static_cast<PySinglePartStringAccumulator *>(userData);
+      assert(!accum->invoked &&
+             "PySinglePartStringAccumulator called back multiple times");
+      accum->invoked = true;
+      accum->value = py::str(part, size);
+    };
+  }
+
+  py::str takeValue() {
+    assert(invoked && "PySinglePartStringAccumulator not called back");
+    return std::move(value);
+  }
+
+private:
+  py::str value;
+  bool invoked = false;
+};
+
 } // namespace
 
+//------------------------------------------------------------------------------
+// PyAttribute.
+//------------------------------------------------------------------------------
+
+bool PyAttribute::operator==(const PyAttribute &other) {
+  return mlirAttributeEqual(attr, other.attr);
+}
+
+//------------------------------------------------------------------------------
+// PyNamedAttribute (name is heap-owned so namedAttr.name stays valid on move).
+//------------------------------------------------------------------------------
+
+PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
+    : ownedName(std::make_unique<std::string>(std::move(ownedName))) {
+  namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
+}
+
 //------------------------------------------------------------------------------
 // PyType.
 //------------------------------------------------------------------------------
@@ -86,6 +131,86 @@ bool PyType::operator==(const PyType &other) {
   return mlirTypeEqual(type, other.type);
 }
 
+//------------------------------------------------------------------------------
+// Standard attribute subclasses.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// CRTP base class for Python attributes that subclass Attribute and should
+/// be castable from it (e.g. via StringAttr(attr)).
+template <typename T>
+class PyConcreteAttribute : public PyAttribute {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = py::class_<T, PyAttribute>;
+  using IsAFunctionTy = int (*)(MlirAttribute);
+
+  PyConcreteAttribute() = default;
+  PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
+  PyConcreteAttribute(PyAttribute &orig)
+      : PyConcreteAttribute(castFrom(orig)) {}
+
+  static MlirAttribute castFrom(PyAttribute &orig) {
+    if (!T::isaFunction(orig.attr)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       llvm::Twine("Cannot cast attribute to ") +
+                           T::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig.attr;
+  }
+
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, T::pyClassName);
+    cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
+    T::bindDerived(cls);
+  }
+
+  /// Derived classes may shadow this to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
+  static constexpr const char *pyClassName = "StringAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyMlirContext &context, std::string value) {
+          MlirAttribute attr =
+              mlirStringAttrGet(context.context, value.size(), &value[0]);
+          return PyStringAttribute(attr);
+        },
+        py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
+    c.def_static(
+        "get_typed",
+        [](PyType &type, std::string value) {
+          MlirAttribute attr =
+              mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
+          return PyStringAttribute(attr);
+        },
+        py::keep_alive<0, 1>(),
+        "Gets a uniqued string attribute associated to a type");
+    c.def_property_readonly(
+        "value",
+        [](PyStringAttribute &self) {
+          PySinglePartStringAccumulator accum;
+          mlirStringAttrGetValue(self.attr, accum.getCallback(),
+                                 accum.getUserData());
+          return accum.takeValue();
+        },
+        "Returns the value of the string attribute");
+  }
+};
+
+} // namespace
+
 //------------------------------------------------------------------------------
 // Standard type subclasses.
 //------------------------------------------------------------------------------
@@ -118,9 +243,9 @@ class PyConcreteType : public PyType {
   }
 
   static void bind(py::module &m) {
-    auto class_ = ClassTy(m, T::pyClassName);
-    class_.def(py::init<PyType &>(), py::keep_alive<0, 1>());
-    T::bindDerived(class_);
+    auto cls = ClassTy(m, T::pyClassName);
+    cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+    T::bindDerived(cls);
   }
 
   /// Implemented by derived classes to add methods to the Python subclass.
@@ -135,21 +260,21 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "signless",
+        "get_signless",
         [](PyMlirContext &context, unsigned width) {
           MlirType t = mlirIntegerTypeGet(context.context, width);
           return PyIntegerType(t);
         },
         py::keep_alive<0, 1>(), "Create a signless integer type");
     c.def_static(
-        "signed",
+        "get_signed",
         [](PyMlirContext &context, unsigned width) {
           MlirType t = mlirIntegerTypeSignedGet(context.context, width);
           return PyIntegerType(t);
         },
         py::keep_alive<0, 1>(), "Create a signed integer type");
     c.def_static(
-        "unsigned",
+        "get_unsigned",
         [](PyMlirContext &context, unsigned width) {
           MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
           return PyIntegerType(t);
@@ -195,6 +320,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           [](PyMlirContext &self, const std::string module) {
             auto moduleRef =
                 mlirModuleCreateParse(self.context, module.c_str());
+            // TODO: Rework error reporting once diagnostic engine is exposed
+            // in C API.
             if (mlirModuleIsNull(moduleRef)) {
               throw SetPyError(
                   PyExc_ValueError,
@@ -203,10 +330,27 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             return PyModule(moduleRef);
           },
           py::keep_alive<0, 1>(), kContextParseDocstring)
+      .def(
+          "parse_attr",
+          [](PyMlirContext &self, std::string attrSpec) {
+            MlirAttribute attr =
+                mlirAttributeParseGet(self.context, attrSpec.c_str());
+            // TODO: Rework error reporting once diagnostic engine is exposed
+            // in C API.
+            if (mlirAttributeIsNull(attr)) {
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("Unable to parse attribute: '") +
+                                   attrSpec + "'");
+            }
+            return PyAttribute(attr);
+          },
+          py::keep_alive<0, 1>(), "Parses the assembly form of an attribute.")
       .def(
           "parse_type",
           [](PyMlirContext &self, std::string typeSpec) {
             MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
+            // TODO: Rework error reporting once diagnostic engine is exposed
+            // in C API.
             if (mlirTypeIsNull(type)) {
               throw SetPyError(PyExc_ValueError,
                                llvm::Twine("Unable to parse type: '") +
@@ -235,6 +379,79 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           kOperationStrDunderDocstring);
 
+  // Mapping of Attribute.
+  py::class_<PyAttribute>(m, "Attribute")
+      .def(
+          "get_named",
+          [](PyAttribute &self, std::string name) {
+            return PyNamedAttribute(self.attr, std::move(name));
+          },
+          py::keep_alive<0, 1>(), "Binds a name to the attribute")
+      .def("__eq__",
+           [](PyAttribute &self, py::object &other) {
+             try {
+               PyAttribute otherAttribute = other.cast<PyAttribute>();
+               return self == otherAttribute;
+             } catch (std::exception &) { // Not castable to an Attribute.
+               return false;
+             }
+           })
+      .def(
+          "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
+          kDumpDocstring)
+      .def(
+          "__str__",
+          [](PyAttribute &self) {
+            PyPrintAccumulator printAccum;
+            mlirAttributePrint(self.attr, printAccum.getCallback(),
+                               printAccum.getUserData());
+            return printAccum.join();
+          },
+          "Returns the assembly form of the attribute.")
+      .def("__repr__", [](PyAttribute &self) {
+        // Generally, assembly formats are not printed for __repr__ because
+        // this can cause exceptionally long debug output.
+        // However, attribute values are generally considered useful and are
+        // printed. This may need to be re-evaluated if debug dumps end up
+        // being excessive.
+        PyPrintAccumulator printAccum;
+        printAccum.parts.append("Attribute(");
+        mlirAttributePrint(self.attr, printAccum.getCallback(),
+                           printAccum.getUserData());
+        printAccum.parts.append(")");
+        return printAccum.join();
+      });
+
+  py::class_<PyNamedAttribute>(m, "NamedAttribute")
+      .def("__repr__",
+           [](PyNamedAttribute &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("NamedAttribute(");
+             printAccum.parts.append(self.namedAttr.name);
+             printAccum.parts.append("=");
+             mlirAttributePrint(self.namedAttr.attribute,
+                                printAccum.getCallback(),
+                                printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "name",
+          [](PyNamedAttribute &self) {
+            return py::str(self.namedAttr.name);
+          },
+          "The name of the NamedAttribute binding")
+      .def_property_readonly(
+          "attr",
+          [](PyNamedAttribute &self) {
+            return PyAttribute(self.namedAttr.attribute);
+          },
+          py::keep_alive<0, 1>(),
+          "The underlying generic attribute of the NamedAttribute binding");
+
+  // Standard attribute bindings.
+  PyStringAttribute::bind(m);
+
   // Mapping of Type.
   py::class_<PyType>(m, "Type")
       .def("__eq__",

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 4e90a9ae9795..1edfc1cead3e 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -45,6 +45,39 @@ class PyModule {
   MlirModule module;
 };
 
+/// Wrapper around the generic MlirAttribute.
+/// The lifetime of an attribute is bound by the PyContext that created it.
+class PyAttribute {
+public:
+  PyAttribute(MlirAttribute attr) : attr(attr) {}
+  bool operator==(const PyAttribute &other);
+
+  MlirAttribute attr;
+};
+
+/// Represents a Python MlirNamedAttribute, carrying an owned copy of the name.
+/// TODO: Refactor this and the C-API to be based on an Identifier owned
+/// by the context so as to avoid ownership issues here.
+class PyNamedAttribute {
+public:
+  /// Constructs a PyNamedAttribute that retains an owned name. This should be
+  /// used in any code that originates an MlirNamedAttribute from a python
+  /// string.
+  /// The passed attribute (and the context that owns it) must outlive this
+  /// object, since only the raw MlirAttribute handle is retained.
+  PyNamedAttribute(MlirAttribute attr, std::string ownedName);
+
+  MlirNamedAttribute namedAttr;
+
+private:
+  // Since the MlirNamedAttribute contains an internal pointer to the actual
+  // memory of the owned string, it must be heap allocated to remain valid.
+  // Otherwise, strings that fit within the small string optimization threshold
+  // will have their memory address change as the containing object is moved,
+  // resulting in an invalid aliased pointer.
+  std::unique_ptr<std::string> ownedName;
+};
+
 /// Wrapper around the generic MlirType.
 /// The lifetime of a type is bound by the PyContext that created it.
 class PyType {

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp
index 9013c0669794..bd80b8c14702 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.cpp
+++ b/mlir/lib/Bindings/Python/PybindUtils.cpp
@@ -10,8 +10,8 @@
 
 namespace py = pybind11;
 
-pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass,
-                                                     llvm::Twine message) {
+pybind11::error_already_set
+mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
   auto messageStr = message.str();
   PyErr_SetString(excClass, messageStr.c_str());
   return pybind11::error_already_set();

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 1a82f8e824ec..0c0e069fae03 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -20,7 +20,8 @@ namespace python {
 // python runtime.
 // Correct usage:
 //   throw SetPyError(PyExc_ValueError, "Foobar'd");
-pybind11::error_already_set SetPyError(PyObject *excClass, llvm::Twine message);
+pybind11::error_already_set SetPyError(PyObject *excClass,
+                                       const llvm::Twine &message);
 
 } // namespace python
 } // namespace mlir

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
new file mode 100644
index 000000000000..328dfb40b972
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -0,0 +1,119 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+
+
+# CHECK-LABEL: TEST: testParsePrint
+def testParsePrint():
+  ctx = mlir.ir.Context()
+  t = ctx.parse_attr('"hello"')
+  # CHECK: "hello"
+  print(str(t))
+  # CHECK: Attribute("hello")
+  print(repr(t))
+
+run(testParsePrint)
+
+
+# CHECK-LABEL: TEST: testParseError
+# TODO: Hook the diagnostic manager to capture a more meaningful error
+# message.
+def testParseError():
+  ctx = mlir.ir.Context()
+  try:
+    t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST")
+  except ValueError as e:
+    # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
+    print("testParseError:", e)
+  else:
+    print("Exception not produced")
+
+run(testParseError)
+
+
+# CHECK-LABEL: TEST: testAttrEq
+def testAttrEq():
+  ctx = mlir.ir.Context()
+  a1 = ctx.parse_attr('"attr1"')
+  a2 = ctx.parse_attr('"attr2"')
+  a3 = ctx.parse_attr('"attr1"')
+  # CHECK: a1 == a1: True
+  print("a1 == a1:", a1 == a1)
+  # CHECK: a1 == a2: False
+  print("a1 == a2:", a1 == a2)
+  # CHECK: a1 == a3: True
+  print("a1 == a3:", a1 == a3)
+  # CHECK: a1 == None: False
+  print("a1 == None:", a1 == None)
+
+run(testAttrEq)
+
+
+# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
+def testAttrEqDoesNotRaise():
+  ctx = mlir.ir.Context()
+  a1 = ctx.parse_attr('"attr1"')
+  not_an_attr = "foo"
+  # CHECK: False
+  print(a1 == not_an_attr)
+  # CHECK: False
+  print(a1 == None)
+  # CHECK: True
+  print(a1 != None)
+
+run(testAttrEqDoesNotRaise)
+
+
+# CHECK-LABEL: TEST: testStandardAttrCasts
+def testStandardAttrCasts():
+  ctx = mlir.ir.Context()
+  a1 = ctx.parse_attr('"attr1"')
+  astr = mlir.ir.StringAttr(a1)
+  aself = mlir.ir.StringAttr(astr)
+  # CHECK: Attribute("attr1")
+  print(repr(astr))
+  try:
+    tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0"))
+  except ValueError as e:
+    # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
+    print("ValueError:", e)
+  else:
+    print("Exception not produced")
+
+run(testStandardAttrCasts)
+
+
+# CHECK-LABEL: TEST: testStringAttr
+def testStringAttr():
+  ctx = mlir.ir.Context()
+  sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"'))
+  # CHECK: sattr value: stringattr
+  print("sattr value:", sattr.value)
+
+  # Test factory methods.
+  # CHECK: default_get: "foobar"
+  print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar"))
+  # CHECK: typed_get: "12345" : i32
+  print("typed_get:", mlir.ir.StringAttr.get_typed(
+      mlir.ir.IntegerType.get_signless(ctx, 32), "12345"))
+
+run(testStringAttr)
+
+
+# CHECK-LABEL: TEST: testNamedAttr
+def testNamedAttr():
+  ctx = mlir.ir.Context()
+  a = ctx.parse_attr('"stringattr"')
+  named = a.get_named("foobar")  # Short name: fits std::string's SSO buffer.
+  # CHECK: attr: "stringattr"
+  print("attr:", named.attr)
+  # CHECK: name: foobar
+  print("name:", named.name)
+  # CHECK: named: NamedAttribute(foobar="stringattr")
+  print("named:", named)
+
+run(testNamedAttr)

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index cc66b1fdb208..1dce0a95c812 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -117,10 +117,10 @@ def testIntegerType():
   print("u32 unsigned:", u32.is_unsigned)
 
   # CHECK: signless: i16
-  print("signless:", mlir.ir.IntegerType.signless(ctx, 16))
+  print("signless:", mlir.ir.IntegerType.get_signless(ctx, 16))
   # CHECK: signed: si8
-  print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
+  print("signed:", mlir.ir.IntegerType.get_signed(ctx, 8))
   # CHECK: unsigned: ui64
-  print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
+  print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
 
 run(testIntegerType)


        


More information about the Mlir-commits mailing list