[Mlir-commits] [mlir] 5f0c1e3 - [mlir][Python] Sync Python bindings with C API MlirStringRef modification.

Stella Laurenzo llvmlistbot at llvm.org
Tue Nov 24 11:35:36 PST 2020


Author: zhanghb97
Date: 2020-11-24T19:33:48Z
New Revision: 5f0c1e380661f5a28fb8e87d70a68fa31e923436

URL: https://github.com/llvm/llvm-project/commit/5f0c1e380661f5a28fb8e87d70a68fa31e923436
DIFF: https://github.com/llvm/llvm-project/commit/5f0c1e380661f5a28fb8e87d70a68fa31e923436.diff

LOG: [mlir][Python] Sync Python bindings with C API MlirStringRef modification.

The MLIR C API now uses `MlirStringRef` instead of `const char *` for string types. This patch syncs the Python bindings with that C API change.

Differential Revision: https://reviews.llvm.org/D92007

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/PybindUtils.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 1e848c2d1531..e145a58d0d27 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -145,6 +145,11 @@ createCustomDialectWrapper(const std::string &dialectNamespace,
   // Create the custom implementation.
   return (*dialectClass)(std::move(dialectDescriptor));
 }
+
+/// Returns a non-owning MlirStringRef view of `s`; `s` must outlive its use.
+static MlirStringRef toMlirStringRef(const std::string &s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -902,7 +907,8 @@ py::object PyOperation::create(
 
   // Apply unpacked/validated to the operation state. Beyond this
   // point, exceptions cannot be thrown or else the state will leak.
-  MlirOperationState state = mlirOperationStateGet(name.c_str(), location->loc);
+  MlirOperationState state =
+      mlirOperationStateGet(toMlirStringRef(name), location->loc);
   if (!mlirOperands.empty())
     mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                   mlirOperands.data());
@@ -917,7 +923,7 @@ py::object PyOperation::create(
     mlirNamedAttributes.reserve(mlirAttributes.size());
     for (auto &it : mlirAttributes)
       mlirNamedAttributes.push_back(
-          mlirNamedAttributeGet(it.first.c_str(), it.second));
+          mlirNamedAttributeGet(toMlirStringRef(it.first), it.second));
     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                     mlirNamedAttributes.data());
   }
@@ -1076,7 +1082,7 @@ bool PyAttribute::operator==(const PyAttribute &other) {
 
 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
     : ownedName(new std::string(std::move(ownedName))) {
-  namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
+  namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr);
 }
 
 //------------------------------------------------------------------------------
@@ -1287,8 +1293,8 @@ class PyOpAttributeMap {
   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
 
   PyAttribute dunderGetItemNamed(const std::string &name) {
-    MlirAttribute attr =
-        mlirOperationGetAttributeByName(operation->get(), name.c_str());
+    MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
+                                                         toMlirStringRef(name));
     if (mlirAttributeIsNull(attr)) {
       throw SetPyError(PyExc_KeyError,
                        "attempt to access a non-existent attribute");
@@ -1303,16 +1309,19 @@ class PyOpAttributeMap {
     }
     MlirNamedAttribute namedAttr =
         mlirOperationGetAttribute(operation->get(), index);
-    return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
+    return PyNamedAttribute(namedAttr.attribute,
+                            std::string(namedAttr.name.data,
+                                        namedAttr.name.length));
   }
 
   void dunderSetItem(const std::string &name, PyAttribute attr) {
-    mlirOperationSetAttributeByName(operation->get(), name.c_str(), attr.attr);
+    mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
+                                    attr.attr);
   }
 
   void dunderDelItem(const std::string &name) {
-    int removed =
-        mlirOperationRemoveAttributeByName(operation->get(), name.c_str());
+    int removed = mlirOperationRemoveAttributeByName(operation->get(),
+                                                     toMlirStringRef(name));
     if (!removed)
       throw SetPyError(PyExc_KeyError,
                        "attempt to delete a non-existent attribute");
@@ -1323,8 +1332,8 @@ class PyOpAttributeMap {
   }
 
   bool dunderContains(const std::string &name) {
-    return !mlirAttributeIsNull(
-        mlirOperationGetAttributeByName(operation->get(), name.c_str()));
+    return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
+        operation->get(), toMlirStringRef(name)));
   }
 
   static void bind(py::module &m) {
@@ -2599,9 +2608,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           "file",
           [](std::string filename, int line, int col,
              DefaultingPyMlirContext context) {
-            return PyLocation(context->getRef(),
-                              mlirLocationFileLineColGet(
-                                  context->get(), filename.c_str(), line, col));
+            return PyLocation(
+                context->getRef(),
+                mlirLocationFileLineColGet(
+                    context->get(), toMlirStringRef(filename), line, col));
           },
           py::arg("filename"), py::arg("line"), py::arg("col"),
           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
@@ -2625,8 +2635,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_static(
           "parse",
           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
-            MlirModule module =
-                mlirModuleCreateParse(context->get(), moduleAsm.c_str());
+            MlirModule module = mlirModuleCreateParse(
+                context->get(), toMlirStringRef(moduleAsm));
             // TODO: Rework error reporting once diagnostic engine is exposed
             // in C API.
             if (mlirModuleIsNull(module)) {
@@ -2875,8 +2885,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_static(
           "parse",
           [](std::string attrSpec, DefaultingPyMlirContext context) {
-            MlirAttribute type =
-                mlirAttributeParseGet(context->get(), attrSpec.c_str());
+            MlirAttribute type = mlirAttributeParseGet(
+                context->get(), toMlirStringRef(attrSpec));
             // TODO: Rework error reporting once diagnostic engine is exposed
             // in C API.
             if (mlirAttributeIsNull(type)) {
@@ -2940,7 +2950,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
            [](PyNamedAttribute &self) {
              PyPrintAccumulator printAccum;
              printAccum.parts.append("NamedAttribute(");
-             printAccum.parts.append(self.namedAttr.name);
+             printAccum.parts.append(py::str(self.namedAttr.name.data,
+                                             self.namedAttr.name.length));
              printAccum.parts.append("=");
              mlirAttributePrint(self.namedAttr.attribute,
                                 printAccum.getCallback(),
@@ -2951,7 +2962,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly(
           "name",
           [](PyNamedAttribute &self) {
-            return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
+            return py::str(self.namedAttr.name.data,
+                           self.namedAttr.name.length);
           },
           "The name of the NamedAttribute binding")
       .def_property_readonly(
@@ -2983,7 +2995,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_static(
           "parse",
           [](std::string typeSpec, DefaultingPyMlirContext context) {
-            MlirType type = mlirTypeParseGet(context->get(), typeSpec.c_str());
+            MlirType type =
+                mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
             // TODO: Rework error reporting once diagnostic engine is exposed
             // in C API.
             if (mlirTypeIsNull(type)) {

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 25cbba282129..4116e9f30b6b 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -16,7 +16,6 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
-
 namespace mlir {
 namespace python {
 
@@ -115,10 +114,11 @@ struct PyPrintAccumulator {
   void *getUserData() { return this; }
 
   MlirStringCallback getCallback() {
-    return [](const char *part, intptr_t size, void *userData) {
+    return [](MlirStringRef part, void *userData) {
       PyPrintAccumulator *printAccum =
           static_cast<PyPrintAccumulator *>(userData);
-      pybind11::str pyPart(part, size); // Decodes as UTF-8 by default.
+      // Decodes as UTF-8 by default.
+      pybind11::str pyPart(part.data, part.length);
       printAccum->parts.append(std::move(pyPart));
     };
   }
@@ -139,15 +139,17 @@ class PyFileAccumulator {
   void *getUserData() { return this; }
 
   MlirStringCallback getCallback() {
-    return [](const char *part, intptr_t size, void *userData) {
-      pybind11::gil_scoped_acquire();
+    return [](MlirStringRef part, void *userData) {
+      // Must be a named local: a temporary would release the GIL immediately.
+      pybind11::gil_scoped_acquire acquire;
       PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
       if (accum->binary) {
         // Note: Still has to copy and not avoidable with this API.
-        pybind11::bytes pyBytes(part, size);
+        pybind11::bytes pyBytes(part.data, part.length);
         accum->pyWriteFunction(pyBytes);
       } else {
-        pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
+        // Decodes as UTF-8 by default.
+        pybind11::str pyStr(part.data, part.length);
         accum->pyWriteFunction(pyStr);
       }
     };
@@ -165,13 +167,13 @@ struct PySinglePartStringAccumulator {
   void *getUserData() { return this; }
 
   MlirStringCallback getCallback() {
-    return [](const char *part, intptr_t size, void *userData) {
+    return [](MlirStringRef part, void *userData) {
       PySinglePartStringAccumulator *accum =
           static_cast<PySinglePartStringAccumulator *>(userData);
       assert(!accum->invoked &&
              "PySinglePartStringAccumulator called back multiple times");
       accum->invoked = true;
-      accum->value = pybind11::str(part, size);
+      accum->value = pybind11::str(part.data, part.length);
     };
   }
 


        


More information about the Mlir-commits mailing list