[Mlir-commits] [mlir] 5f0c1e3 - [mlir][Python] Sync Python bindings with C API MlirStringRef modification.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Nov 24 11:35:36 PST 2020
Author: zhanghb97
Date: 2020-11-24T19:33:48Z
New Revision: 5f0c1e380661f5a28fb8e87d70a68fa31e923436
URL: https://github.com/llvm/llvm-project/commit/5f0c1e380661f5a28fb8e87d70a68fa31e923436
DIFF: https://github.com/llvm/llvm-project/commit/5f0c1e380661f5a28fb8e87d70a68fa31e923436.diff
LOG: [mlir][Python] Sync Python bindings with C API MlirStringRef modification.
The MLIR C API now uses `MlirStringRef` instead of `const char *` for strings. This patch syncs the Python bindings with that C API change.
Differential Revision: https://reviews.llvm.org/D92007
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/PybindUtils.h
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 1e848c2d1531..e145a58d0d27 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -145,6 +145,11 @@ createCustomDialectWrapper(const std::string &dialectNamespace,
// Create the custom implementation.
return (*dialectClass)(std::move(dialectDescriptor));
}
+
+static MlirStringRef toMlirStringRef(const std::string &s) {
+ return mlirStringRefCreate(s.data(), s.size());
+}
+
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@@ -902,7 +907,8 @@ py::object PyOperation::create(
// Apply unpacked/validated to the operation state. Beyond this
// point, exceptions cannot be thrown or else the state will leak.
- MlirOperationState state = mlirOperationStateGet(name.c_str(), location->loc);
+ MlirOperationState state =
+ mlirOperationStateGet(toMlirStringRef(name), location->loc);
if (!mlirOperands.empty())
mlirOperationStateAddOperands(&state, mlirOperands.size(),
mlirOperands.data());
@@ -917,7 +923,7 @@ py::object PyOperation::create(
mlirNamedAttributes.reserve(mlirAttributes.size());
for (auto &it : mlirAttributes)
mlirNamedAttributes.push_back(
- mlirNamedAttributeGet(it.first.c_str(), it.second));
+ mlirNamedAttributeGet(toMlirStringRef(it.first), it.second));
mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
mlirNamedAttributes.data());
}
@@ -1076,7 +1082,7 @@ bool PyAttribute::operator==(const PyAttribute &other) {
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
: ownedName(new std::string(std::move(ownedName))) {
- namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
+ namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr);
}
//------------------------------------------------------------------------------
@@ -1287,8 +1293,8 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
PyAttribute dunderGetItemNamed(const std::string &name) {
- MlirAttribute attr =
- mlirOperationGetAttributeByName(operation->get(), name.c_str());
+ MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
+ toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
throw SetPyError(PyExc_KeyError,
"attempt to access a non-existent attribute");
@@ -1303,16 +1309,18 @@ class PyOpAttributeMap {
}
MlirNamedAttribute namedAttr =
mlirOperationGetAttribute(operation->get(), index);
- return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
+ return PyNamedAttribute(namedAttr.attribute,
+ std::string(namedAttr.name.data));
}
void dunderSetItem(const std::string &name, PyAttribute attr) {
- mlirOperationSetAttributeByName(operation->get(), name.c_str(), attr.attr);
+ mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
+ attr.attr);
}
void dunderDelItem(const std::string &name) {
- int removed =
- mlirOperationRemoveAttributeByName(operation->get(), name.c_str());
+ int removed = mlirOperationRemoveAttributeByName(operation->get(),
+ toMlirStringRef(name));
if (!removed)
throw SetPyError(PyExc_KeyError,
"attempt to delete a non-existent attribute");
@@ -1323,8 +1331,8 @@ class PyOpAttributeMap {
}
bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(
- mlirOperationGetAttributeByName(operation->get(), name.c_str()));
+ return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
+ operation->get(), toMlirStringRef(name)));
}
static void bind(py::module &m) {
@@ -2599,9 +2607,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"file",
[](std::string filename, int line, int col,
DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFileLineColGet(
- context->get(), filename.c_str(), line, col));
+ return PyLocation(
+ context->getRef(),
+ mlirLocationFileLineColGet(
+ context->get(), toMlirStringRef(filename), line, col));
},
py::arg("filename"), py::arg("line"), py::arg("col"),
py::arg("context") = py::none(), kContextGetFileLocationDocstring)
@@ -2625,8 +2634,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"parse",
[](const std::string moduleAsm, DefaultingPyMlirContext context) {
- MlirModule module =
- mlirModuleCreateParse(context->get(), moduleAsm.c_str());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirModuleIsNull(module)) {
@@ -2875,8 +2884,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
- MlirAttribute type =
- mlirAttributeParseGet(context->get(), attrSpec.c_str());
+ MlirAttribute type = mlirAttributeParseGet(
+ context->get(), toMlirStringRef(attrSpec));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(type)) {
@@ -2940,7 +2949,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
[](PyNamedAttribute &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append("NamedAttribute(");
- printAccum.parts.append(self.namedAttr.name);
+ printAccum.parts.append(self.namedAttr.name.data);
printAccum.parts.append("=");
mlirAttributePrint(self.namedAttr.attribute,
printAccum.getCallback(),
@@ -2951,7 +2960,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly(
"name",
[](PyNamedAttribute &self) {
- return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
+ return py::str(self.namedAttr.name.data,
+ self.namedAttr.name.length);
},
"The name of the NamedAttribute binding")
.def_property_readonly(
@@ -2983,7 +2993,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"parse",
[](std::string typeSpec, DefaultingPyMlirContext context) {
- MlirType type = mlirTypeParseGet(context->get(), typeSpec.c_str());
+ MlirType type =
+ mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(type)) {
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 25cbba282129..4116e9f30b6b 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -16,7 +16,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
-
namespace mlir {
namespace python {
@@ -115,10 +114,11 @@ struct PyPrintAccumulator {
void *getUserData() { return this; }
MlirStringCallback getCallback() {
- return [](const char *part, intptr_t size, void *userData) {
+ return [](MlirStringRef part, void *userData) {
PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData);
- pybind11::str pyPart(part, size); // Decodes as UTF-8 by default.
+ pybind11::str pyPart(part.data,
+ part.length); // Decodes as UTF-8 by default.
printAccum->parts.append(std::move(pyPart));
};
}
@@ -139,15 +139,16 @@ class PyFileAccumulator {
void *getUserData() { return this; }
MlirStringCallback getCallback() {
- return [](const char *part, intptr_t size, void *userData) {
+ return [](MlirStringRef part, void *userData) {
pybind11::gil_scoped_acquire();
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
if (accum->binary) {
// Note: Still has to copy and not avoidable with this API.
- pybind11::bytes pyBytes(part, size);
+ pybind11::bytes pyBytes(part.data, part.length);
accum->pyWriteFunction(pyBytes);
} else {
- pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
+ pybind11::str pyStr(part.data,
+ part.length); // Decodes as UTF-8 by default.
accum->pyWriteFunction(pyStr);
}
};
@@ -165,13 +166,13 @@ struct PySinglePartStringAccumulator {
void *getUserData() { return this; }
MlirStringCallback getCallback() {
- return [](const char *part, intptr_t size, void *userData) {
+ return [](MlirStringRef part, void *userData) {
PySinglePartStringAccumulator *accum =
static_cast<PySinglePartStringAccumulator *>(userData);
assert(!accum->invoked &&
"PySinglePartStringAccumulator called back multiple times");
accum->invoked = true;
- accum->value = pybind11::str(part, size);
+ accum->value = pybind11::str(part.data, part.length);
};
}
More information about the Mlir-commits
mailing list