[llvm-branch-commits] [mlir] ba0fe76 - [mlir][Python] Add an Operation.result property.
Stella Laurenzo via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Nov 29 18:16:01 PST 2020
Author: Stella Laurenzo
Date: 2020-11-29T18:09:07-08:00
New Revision: ba0fe76b7eb87f91499931e76317ddd1cb493aa1
URL: https://github.com/llvm/llvm-project/commit/ba0fe76b7eb87f91499931e76317ddd1cb493aa1
DIFF: https://github.com/llvm/llvm-project/commit/ba0fe76b7eb87f91499931e76317ddd1cb493aa1.diff
LOG: [mlir][Python] Add an Operation.result property.
* If ODS redefines this property, that is fine, but I have found this accessor to be universally useful in the old npcomp bindings, and adding it closes one of the gaps that currently keep me from switching over.
Differential Revision: https://reviews.llvm.org/D92287
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/test/Bindings/Python/ir_operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index d34fe998583f..d270e44debae 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -23,6 +23,8 @@ using namespace mlir;
using namespace mlir::python;
using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
//------------------------------------------------------------------------------
// Docstrings (trivial, non-duplicated docstrings are included inline).
@@ -631,7 +633,7 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
getContext()->get(), {canonKey->data(), canonKey->size()});
if (mlirDialectIsNull(dialect)) {
throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
- llvm::Twine("Dialect '") + key + "' not found");
+ Twine("Dialect '") + key + "' not found");
}
return dialect;
}
@@ -793,7 +795,7 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
return created;
}
-void PyOperation::checkValid() {
+void PyOperation::checkValid() const {
if (!valid) {
throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
}
@@ -817,7 +819,7 @@ void PyOperationBase::print(py::object fileObject, bool binary,
PyFileAccumulator accum(fileObject, binary);
py::gil_scoped_release();
- mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(),
+ mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
accum.getUserData());
mlirOpPrintingFlagsDestroy(flags);
}
@@ -975,7 +977,7 @@ py::object PyOperation::createOpView() {
MlirIdentifier ident = mlirOperationGetName(get());
MlirStringRef identStr = mlirIdentifierStr(ident);
auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
- llvm::StringRef(identStr.data, identStr.length));
+ StringRef(identStr.data, identStr.length));
if (opViewClass)
return (*opViewClass)(getRef().getObject());
return py::cast(PyOpView(getRef().getObject()));
@@ -1044,7 +1046,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) {
(*refOperation)->checkValid();
beforeOp = (*refOperation)->get();
}
- mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get());
+ mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
operation.setAttached();
}
@@ -1158,7 +1160,7 @@ class PyConcreteValue : public PyValue {
static MlirValue castFrom(PyValue &orig) {
if (!DerivedTy::isaFunction(orig.get())) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") +
+ throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
@@ -1416,9 +1418,9 @@ class PyConcreteAttribute : public BaseTy {
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError,
- llvm::Twine("Cannot cast attribute to ") +
- DerivedTy::pyClassName + " (from " + origRepr + ")");
+ throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
+ DerivedTy::pyClassName +
+ " (from " + origRepr + ")");
}
return orig;
}
@@ -1449,7 +1451,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
// in C API.
if (mlirAttributeIsNull(attr)) {
throw SetPyError(PyExc_ValueError,
- llvm::Twine("invalid '") +
+ Twine("invalid '") +
py::repr(py::cast(type)).cast<std::string>() +
"' and expected floating point type.");
}
@@ -1943,7 +1945,7 @@ class PyConcreteType : public BaseTy {
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+ throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
@@ -2142,7 +2144,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
}
throw SetPyError(
PyExc_ValueError,
- llvm::Twine("invalid '") +
+ Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
},
@@ -2247,7 +2249,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
- llvm::Twine("invalid '") +
+ Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
}
@@ -2278,7 +2280,7 @@ class PyRankedTensorType
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
- llvm::Twine("invalid '") +
+ Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@@ -2309,7 +2311,7 @@ class PyUnrankedTensorType
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
- llvm::Twine("invalid '") +
+ Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@@ -2344,7 +2346,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
- llvm::Twine("invalid '") +
+ Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@@ -2390,7 +2392,7 @@ class PyUnrankedMemRefType
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
- llvm::Twine("invalid '") +
+ Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@@ -2544,7 +2546,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
self.get(), {name.data(), name.size()});
if (mlirDialectIsNull(dialect)) {
throw SetPyError(PyExc_ValueError,
- llvm::Twine("Dialect '") + name + "' not found");
+ Twine("Dialect '") + name + "' not found");
}
return PyDialectDescriptor(self.getRef(), dialect);
},
@@ -2763,6 +2765,26 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return PyOpResultList(self.getOperation().getRef());
},
"Returns the list of Operation results.")
+ .def_property_readonly(
+ "result",
+ [](PyOperationBase &self) {
+ auto &operation = self.getOperation();
+ auto numResults = mlirOperationGetNumResults(operation);
+ if (numResults != 1) {
+ auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+ throw SetPyError(
+ PyExc_ValueError,
+ Twine("Cannot call .result on operation ") +
+ StringRef(name.data, name.length) + " which has " +
+ Twine(numResults) +
+ " results (it is only valid for operations with a "
+ "single result)");
+ }
+ return PyOpResult(operation.getRef(),
+ mlirOperationGetResult(operation, 0));
+ },
+ "Shortcut to get an op result if it has only one (throws an error "
+ "otherwise).")
.def("__iter__",
[](PyOperationBase &self) {
return PyRegionIterator(self.getOperation().getRef());
@@ -2931,7 +2953,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// in C API.
if (mlirAttributeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
- llvm::Twine("Unable to parse attribute: '") +
+ Twine("Unable to parse attribute: '") +
attrSpec + "'");
}
return PyAttribute(context->getRef(), type);
@@ -3042,8 +3064,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// in C API.
if (mlirTypeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
- llvm::Twine("Unable to parse type: '") +
- typeSpec + "'");
+ Twine("Unable to parse type: '") + typeSpec +
+ "'");
}
return PyType(context->getRef(), type);
},
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index d24607fb02c2..0cdc7e6a66fe 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -425,7 +425,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
pybind11::object parentKeepAlive = pybind11::object());
/// Gets the backing operation.
- MlirOperation get() {
+ operator MlirOperation() const { return get(); }
+ MlirOperation get() const {
checkValid();
return operation;
}
@@ -440,7 +441,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
assert(!attached && "operation already attached");
attached = true;
}
- void checkValid();
+ void checkValid() const;
/// Gets the owning block or raises an exception if the operation has no
/// owning block.
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index ddc4c2129844..e3867b99a9b4 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -474,6 +474,7 @@ def testOperationPrint():
run(testOperationPrint)
+# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
with Context(), Location.unknown():
Context.current.allow_unregistered_dialects = True
@@ -503,3 +504,36 @@ def testKnownOpView():
print(repr(custom))
run(testKnownOpView)
+
+
+# CHECK-LABEL: TEST: testSingleResultProperty
+def testSingleResultProperty():
+ with Context(), Location.unknown():
+ Context.current.allow_unregistered_dialects = True
+ module = Module.parse(r"""
+ "custom.no_result"() : () -> ()
+ %0:2 = "custom.two_result"() : () -> (f32, f32)
+ %1 = "custom.one_result"() : () -> f32
+ """)
+ print(module)
+
+ try:
+ module.body.operations[0].result
+ except ValueError as e:
+ # CHECK: Cannot call .result on operation custom.no_result which has 0 results
+ print(e)
+ else:
+ assert False, "Expected exception"
+
+ try:
+ module.body.operations[1].result
+ except ValueError as e:
+ # CHECK: Cannot call .result on operation custom.two_result which has 2 results
+ print(e)
+ else:
+ assert False, "Expected exception"
+
+ # CHECK: %1 = "custom.one_result"() : () -> f32
+ print(module.body.operations[2])
+
+run(testSingleResultProperty)
More information about the llvm-branch-commits
mailing list