[llvm-branch-commits] [mlir] ba0fe76 - [mlir][Python] Add an Operation.result property.

Stella Laurenzo via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Nov 29 18:16:01 PST 2020


Author: Stella Laurenzo
Date: 2020-11-29T18:09:07-08:00
New Revision: ba0fe76b7eb87f91499931e76317ddd1cb493aa1

URL: https://github.com/llvm/llvm-project/commit/ba0fe76b7eb87f91499931e76317ddd1cb493aa1
DIFF: https://github.com/llvm/llvm-project/commit/ba0fe76b7eb87f91499931e76317ddd1cb493aa1.diff

LOG: [mlir][Python] Add an Operation.result property.

* If ODS redefines this, it is fine, but I have found this accessor to be universally useful in the old npcomp bindings and I'm closing gaps that will let me switch.

Differential Revision: https://reviews.llvm.org/D92287

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index d34fe998583f..d270e44debae 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -23,6 +23,8 @@ using namespace mlir;
 using namespace mlir::python;
 
 using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
 
 //------------------------------------------------------------------------------
 // Docstrings (trivial, non-duplicated docstrings are included inline).
@@ -631,7 +633,7 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
       getContext()->get(), {canonKey->data(), canonKey->size()});
   if (mlirDialectIsNull(dialect)) {
     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
-                     llvm::Twine("Dialect '") + key + "' not found");
+                     Twine("Dialect '") + key + "' not found");
   }
   return dialect;
 }
@@ -793,7 +795,7 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
   return created;
 }
 
-void PyOperation::checkValid() {
+void PyOperation::checkValid() const {
   if (!valid) {
     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
   }
@@ -817,7 +819,7 @@ void PyOperationBase::print(py::object fileObject, bool binary,
 
   PyFileAccumulator accum(fileObject, binary);
   py::gil_scoped_release();
-  mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(),
+  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
                               accum.getUserData());
   mlirOpPrintingFlagsDestroy(flags);
 }
@@ -975,7 +977,7 @@ py::object PyOperation::createOpView() {
   MlirIdentifier ident = mlirOperationGetName(get());
   MlirStringRef identStr = mlirIdentifierStr(ident);
   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
-      llvm::StringRef(identStr.data, identStr.length));
+      StringRef(identStr.data, identStr.length));
   if (opViewClass)
     return (*opViewClass)(getRef().getObject());
   return py::cast(PyOpView(getRef().getObject()));
@@ -1044,7 +1046,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) {
     (*refOperation)->checkValid();
     beforeOp = (*refOperation)->get();
   }
-  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get());
+  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
   operation.setAttached();
 }
 
@@ -1158,7 +1160,7 @@ class PyConcreteValue : public PyValue {
   static MlirValue castFrom(PyValue &orig) {
     if (!DerivedTy::isaFunction(orig.get())) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") +
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
                                              DerivedTy::pyClassName +
                                              " (from " + origRepr + ")");
     }
@@ -1416,9 +1418,9 @@ class PyConcreteAttribute : public BaseTy {
   static MlirAttribute castFrom(PyAttribute &orig) {
     if (!DerivedTy::isaFunction(orig)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError,
-                       llvm::Twine("Cannot cast attribute to ") +
-                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
     }
     return orig;
   }
@@ -1449,7 +1451,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
           // in C API.
           if (mlirAttributeIsNull(attr)) {
             throw SetPyError(PyExc_ValueError,
-                             llvm::Twine("invalid '") +
+                             Twine("invalid '") +
                                  py::repr(py::cast(type)).cast<std::string>() +
                                  "' and expected floating point type.");
           }
@@ -1943,7 +1945,7 @@ class PyConcreteType : public BaseTy {
   static MlirType castFrom(PyType &orig) {
     if (!DerivedTy::isaFunction(orig)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
                                              DerivedTy::pyClassName +
                                              " (from " + origRepr + ")");
     }
@@ -2142,7 +2144,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
           }
           throw SetPyError(
               PyExc_ValueError,
-              llvm::Twine("invalid '") +
+              Twine("invalid '") +
                   py::repr(py::cast(elementType)).cast<std::string>() +
                   "' and expected floating point or integer type.");
         },
@@ -2247,7 +2249,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point or integer type.");
           }
@@ -2278,7 +2280,7 @@ class PyRankedTensorType
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
@@ -2309,7 +2311,7 @@ class PyUnrankedTensorType
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
@@ -2344,7 +2346,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
            if (mlirTypeIsNull(t)) {
              throw SetPyError(
                  PyExc_ValueError,
-                 llvm::Twine("invalid '") +
+                 Twine("invalid '") +
                      py::repr(py::cast(elementType)).cast<std::string>() +
                      "' and expected floating point, integer, vector or "
                      "complex "
@@ -2390,7 +2392,7 @@ class PyUnrankedMemRefType
            if (mlirTypeIsNull(t)) {
              throw SetPyError(
                  PyExc_ValueError,
-                 llvm::Twine("invalid '") +
+                 Twine("invalid '") +
                      py::repr(py::cast(elementType)).cast<std::string>() +
                      "' and expected floating point, integer, vector or "
                      "complex "
@@ -2544,7 +2546,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                 self.get(), {name.data(), name.size()});
             if (mlirDialectIsNull(dialect)) {
               throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Dialect '") + name + "' not found");
+                               Twine("Dialect '") + name + "' not found");
             }
             return PyDialectDescriptor(self.getRef(), dialect);
           },
@@ -2763,6 +2765,26 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             return PyOpResultList(self.getOperation().getRef());
           },
           "Returns the list of Operation results.")
+      .def_property_readonly(
+          "result",
+          [](PyOperationBase &self) {
+            auto &operation = self.getOperation();
+            auto numResults = mlirOperationGetNumResults(operation);
+            if (numResults != 1) {
+              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+              throw SetPyError(
+                  PyExc_ValueError,
+                  Twine("Cannot call .result on operation ") +
+                      StringRef(name.data, name.length) + " which has " +
+                      Twine(numResults) +
+                      " results (it is only valid for operations with a "
+                      "single result)");
+            }
+            return PyOpResult(operation.getRef(),
+                              mlirOperationGetResult(operation, 0));
+          },
+          "Shortcut to get an op result if it has only one (throws an error "
+          "otherwise).")
       .def("__iter__",
            [](PyOperationBase &self) {
              return PyRegionIterator(self.getOperation().getRef());
@@ -2931,7 +2953,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             // in C API.
             if (mlirAttributeIsNull(type)) {
               throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Unable to parse attribute: '") +
+                               Twine("Unable to parse attribute: '") +
                                    attrSpec + "'");
             }
             return PyAttribute(context->getRef(), type);
@@ -3042,8 +3064,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             // in C API.
             if (mlirTypeIsNull(type)) {
               throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Unable to parse type: '") +
-                                   typeSpec + "'");
+                               Twine("Unable to parse type: '") + typeSpec +
+                                   "'");
             }
             return PyType(context->getRef(), type);
           },

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index d24607fb02c2..0cdc7e6a66fe 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -425,7 +425,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
                  pybind11::object parentKeepAlive = pybind11::object());
 
   /// Gets the backing operation.
-  MlirOperation get() {
+  operator MlirOperation() const { return get(); }
+  MlirOperation get() const {
     checkValid();
     return operation;
   }
@@ -440,7 +441,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
     assert(!attached && "operation already attached");
     attached = true;
   }
-  void checkValid();
+  void checkValid() const;
 
   /// Gets the owning block or raises an exception if the operation has no
   /// owning block.

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index ddc4c2129844..e3867b99a9b4 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -474,6 +474,7 @@ def testOperationPrint():
 run(testOperationPrint)
 
 
+# CHECK-LABEL: TEST: testKnownOpView
 def testKnownOpView():
   with Context(), Location.unknown():
     Context.current.allow_unregistered_dialects = True
@@ -503,3 +504,36 @@ def testKnownOpView():
     print(repr(custom))
 
 run(testKnownOpView)
+
+
+# CHECK-LABEL: TEST: testSingleResultProperty
+def testSingleResultProperty():
+  with Context(), Location.unknown():
+    Context.current.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      "custom.no_result"() : () -> ()
+      %0:2 = "custom.two_result"() : () -> (f32, f32)
+      %1 = "custom.one_result"() : () -> f32
+    """)
+    print(module)
+
+  try:
+    module.body.operations[0].result
+  except ValueError as e:
+    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
+    print(e)
+  else:
+    assert False, "Expected exception"
+
+  try:
+    module.body.operations[1].result
+  except ValueError as e:
+    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
+    print(e)
+  else:
+    assert False, "Expected exception"
+
+  # CHECK: %1 = "custom.one_result"() : () -> f32
+  print(module.body.operations[2])
+
+run(testSingleResultProperty)


        


More information about the llvm-branch-commits mailing list