[Mlir-commits] [mlir] 580915d - [mlir] Expose Value hierarchy to Python bindings

Alex Zinenko llvmlistbot at llvm.org
Wed Oct 21 00:52:52 PDT 2020


Author: Alex Zinenko
Date: 2020-10-21T09:49:22+02:00
New Revision: 580915d6a2970022d5b7e05d4587de0fd7126c31

URL: https://github.com/llvm/llvm-project/commit/580915d6a2970022d5b7e05d4587de0fd7126c31
DIFF: https://github.com/llvm/llvm-project/commit/580915d6a2970022d5b7e05d4587de0fd7126c31.diff

LOG: [mlir] Expose Value hierarchy to Python bindings

Values are ubiquitous in the IR; in particular, block arguments and operation
results are Values. Define Python classes for BlockArgument, OpResult and their
common ancestor Value. Define pseudo-container classes for lists of block
arguments and operation results, and use these containers to access the
corresponding values in blocks and operations.

Differential Revision: https://reviews.llvm.org/D89778

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 816123472647..2aeb306f7256 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -432,6 +432,9 @@ intptr_t mlirOpResultGetResultNumber(MlirValue value);
 /** Returns the type of the value. */
 MlirType mlirValueGetType(MlirValue value);
 
+/** Prints the value to the standard error stream. */
+void mlirValueDump(MlirValue value);
+
 /** Prints a value by sending chunks of the string representation and
  * forwarding `userData to `callback`. Note that the callback may be called
  * several times with consecutive chunks of the string. */

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 2a768df0ffd9..0c3e541d18b2 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -85,6 +85,14 @@ static const char kAppendBlockDocstring[] =
   The created block.
 )";
 
+static const char kValueDunderStrDocstring[] =
+    R"(Returns the string form of the value, wrapped as `Value(...)`.
+
+If the value is a block argument, this is the assembly form of its type and the
+position in the argument list. If the value is an operation result, this is
+equivalent to printing the operation that produced it.
+)";
+
 //------------------------------------------------------------------------------
 // Conversion utilities.
 //------------------------------------------------------------------------------
@@ -732,6 +740,168 @@ bool PyType::operator==(const PyType &other) {
   return mlirTypeEqual(type, other.type);
 }
 
+//------------------------------------------------------------------------------
+// PyValue and subclasses.
+//------------------------------------------------------------------------------
+
+namespace {
+/// CRTP base for the Python classes that derive from Value and can be
+/// constructed from one by a checked cast. The hierarchy is a single level
+/// deep and is expected to stay that way unless core MLIR changes.
+template <typename DerivedTy> class PyConcreteValue : public PyValue {
+public:
+  // Each derived class provides the following static members:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and may shadow bindDerived to register its own methods.
+  using IsAFunctionTy = int (*)(MlirValue);
+  using ClassTy = py::class_<DerivedTy, PyValue>;
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Returns the underlying MlirValue of `orig` when `isaFunction` accepts it;
+  /// otherwise throws the error built by SetPyError with PyExc_ValueError.
+  static MlirValue castFrom(PyValue &orig) {
+    MlirValue rawValue = orig.get();
+    if (DerivedTy::isaFunction(rawValue))
+      return rawValue;
+    std::string reprText = py::repr(py::cast(orig)).cast<std::string>();
+    throw SetPyError(PyExc_ValueError,
+                     llvm::Twine("Cannot cast value to ") +
+                         DerivedTy::pyClassName + " (from " + reprText + ")");
+  }
+
+  /// Registers the Python class for `DerivedTy` in module `m`.
+  static void bind(py::module &m) {
+    ClassTy pyClass(m, DerivedTy::pyClassName);
+    pyClass.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
+    DerivedTy::bindDerived(pyClass);
+  }
+
+  /// Hook for derived classes to add methods; does nothing by default.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// Python-side class for values that are block arguments.
+class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
+  static constexpr const char *pyClassName = "BlockArgument";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("owner", [](PyBlockArgument &self) {
+      MlirBlock ownerBlock = mlirBlockArgumentGetOwner(self.get());
+      return PyBlock(self.getParentOperation(), ownerBlock);
+    });
+    c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
+      return mlirBlockArgumentGetArgNumber(self.get());
+    });
+    c.def("set_type", [](PyBlockArgument &self, PyType type) {
+      mlirBlockArgumentSetType(self.get(), type);
+    });
+  }
+};
+
+/// Python wrapper for an MlirValue that is an operation result (`OpResult`).
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("owner", [](PyOpResult &self) {
+      assert(
+          mlirOperationEqual(self.getParentOperation()->get(),
+                             mlirOpResultGetOwner(self.get())) &&
+          "expected the owner of the value in Python to match that in the IR");
+      return self.getParentOperation();
+    });
+    c.def_property_readonly("result_number", [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    });
+  }
+};
+
+/// A list of block arguments. Internally, these are stored as consecutive
+/// elements, so random access is cheap. The argument list is associated with
+/// the operation that contains the block (detached blocks are not allowed in
+/// Python bindings) and extends its lifetime.
+class PyBlockArgumentList {
+public:
+  PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
+      : operation(std::move(operation)), block(block) {}
+
+  /// Returns the length of the block argument list.
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirBlockGetNumArguments(block);
+  }
+
+  /// Returns `index`-th block argument; throws IndexError if out of bounds.
+  PyBlockArgument dunderGetItem(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds block argument");
+    }
+    PyValue value(operation, mlirBlockGetArgument(block, index));
+    return PyBlockArgument(value);
+  }
+
+  /// Defines a Python class in the bindings.
+  static void bind(py::module &m) {
+    py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
+        .def("__len__", &PyBlockArgumentList::dunderLen)
+        .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
+  }
+
+private:
+  PyOperationRef operation;
+  MlirBlock block;
+};
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, so random access is cheap. The result list is associated with the
+/// operation whose results these are, and extends the lifetime of this
+/// operation.
+class PyOpResultList {
+public:
+  PyOpResultList(PyOperationRef operation) : operation(std::move(operation)) {}
+
+  /// Returns the length of the result list.
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirOperationGetNumResults(operation->get());
+  }
+
+  /// Returns `index`-th result; throws IndexError if `index` is out of bounds.
+  PyOpResult dunderGetItem(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds result");
+    }
+    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  /// Defines a Python class in the bindings.
+  static void bind(py::module &m) {
+    py::class_<PyOpResultList>(m, "OpResultList")
+        .def("__len__", &PyOpResultList::dunderLen)
+        .def("__getitem__", &PyOpResultList::dunderGetItem);
+  }
+
+private:
+  PyOperationRef operation;
+};
+
+} // end namespace
+
 //------------------------------------------------------------------------------
 // Standard attribute subclasses.
 //------------------------------------------------------------------------------
@@ -1793,6 +1963,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly(
           "regions",
           [](PyOperation &self) { return PyRegionList(self.getRef()); })
+      .def_property_readonly(
+          "results",
+          [](PyOperation &self) { return PyOpResultList(self.getRef()); },
+          "Returns the list of Operation results.")
       .def("__iter__",
            [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
       .def(
@@ -1833,6 +2007,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
 
   // Mapping of PyBlock.
   py::class_<PyBlock>(m, "Block")
+      .def_property_readonly(
+          "arguments",
+          [](PyBlock &self) {
+            return PyBlockArgumentList(self.getParentOperation(), self.get());
+          },
+          "Returns a list of block arguments.")
       .def_property_readonly(
           "operations",
           [](PyBlock &self) {
@@ -2015,11 +2195,40 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyTupleType::bind(m);
   PyFunctionType::bind(m);
 
+  // Mapping of Value.
+  py::class_<PyValue>(m, "Value")
+      .def_property_readonly(
+          "context",
+          [](PyValue &self) { return self.getParentOperation()->getContext(); },
+          "Context in which the value lives.")
+      .def(
+          "dump", [](PyValue &self) { mlirValueDump(self.get()); },
+          kDumpDocstring)
+      .def(
+          "__str__",
+          [](PyValue &self) {
+            PyPrintAccumulator printAccum;
+            printAccum.parts.append("Value(");
+            mlirValuePrint(self.get(), printAccum.getCallback(),
+                           printAccum.getUserData());
+            printAccum.parts.append(")");
+            return printAccum.join();
+          },
+          kValueDunderStrDocstring)
+      .def_property_readonly("type", [](PyValue &self) {
+        return PyType(self.getParentOperation()->getContext(),
+                      mlirValueGetType(self.get()));
+      });
+  PyBlockArgument::bind(m);
+  PyOpResult::bind(m);
+
   // Container bindings.
+  PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);
   PyBlockList::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
+  PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);
 }

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index c175018c8bb6..947b7343e35a 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -23,6 +23,7 @@ class PyMlirContext;
 class PyModule;
 class PyOperation;
 class PyType;
+class PyValue;
 
 /// Template for a reference to a concrete type which captures a python
 /// reference to its underlying python object.
@@ -381,6 +382,27 @@ class PyType : public BaseContextObject {
   MlirType type;
 };
 
+/// Wrapper around the generic MlirValue.
+/// A value is managed by the operation responsible for its definition, and the
+/// wrapper holds a reference to that operation. For an op result, this is the
+/// operation defining the value. For a block argument, this is the operation
+/// containing the block the value is an argument of (blocks cannot be detached
+/// in Python bindings, so such an operation always exists).
+class PyValue {
+public:
+  PyValue(PyOperationRef parentOperation, MlirValue value)
+      : parentOperation(parentOperation), value(value) {}
+
+  PyOperationRef &getParentOperation() { return parentOperation; }
+  MlirValue get() { return value; }
+
+  void checkValid() { parentOperation->checkValid(); }
+
+private:
+  PyOperationRef parentOperation;
+  MlirValue value;
+};
+
 void populateIRSubmodule(pybind11::module &m);
 
 } // namespace python

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 4bae43c424fd..104f6fda5c02 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -454,6 +454,8 @@ MlirType mlirValueGetType(MlirValue value) {
   return wrap(unwrap(value).getType());
 }
 
+void mlirValueDump(MlirValue value) { unwrap(value).dump(); }  // to stderr
+
 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
                     void *userData) {
   detail::CallbackOstream stream(callback, userData);

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 37b830558528..e4dc71ac26ef 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -102,6 +102,35 @@ def walk_operations(indent, op):
 run(testTraverseOpRegionBlockIndices)
 
 
+# CHECK-LABEL: TEST: testBlockArgumentList
+def testBlockArgumentList():
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""
+    func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
+      return
+    }
+  """)
+  func = module.operation.regions[0].blocks[0].operations[0]
+  entry_block = func.regions[0].blocks[0]
+  assert len(entry_block.arguments) == 3
+  # CHECK: Argument 0, type i32
+  # CHECK: Argument 1, type f64
+  # CHECK: Argument 2, type index
+  for arg in entry_block.arguments:
+    print(f"Argument {arg.arg_number}, type {arg.type}")
+    new_type = mlir.ir.IntegerType.get_signless(ctx, 8 * (arg.arg_number + 1))
+    arg.set_type(new_type)
+
+  # CHECK: Argument 0, type i8
+  # CHECK: Argument 1, type i16
+  # CHECK: Argument 2, type i24
+  for arg in entry_block.arguments:
+    print(f"Argument {arg.arg_number}, type {arg.type}")
+
+
+run(testBlockArgumentList)
+
+
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
   ctx = mlir.ir.Context()
@@ -196,3 +225,26 @@ def testOperationWithRegion():
   print(module)
 
 run(testOperationWithRegion)
+
+
+# CHECK-LABEL: TEST: testOperationResultList
+def testOperationResultList():
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""
+    func @f1() {
+      %0:3 = call @f2() : () -> (i32, f64, index)
+      return
+    }
+    func @f2() -> (i32, f64, index)
+  """)
+  caller = module.operation.regions[0].blocks[0].operations[0]
+  call = caller.regions[0].blocks[0].operations[0]
+  assert len(call.results) == 3
+  # CHECK: Result 0, type i32
+  # CHECK: Result 1, type f64
+  # CHECK: Result 2, type index
+  for res in call.results:
+    print(f"Result {res.result_number}, type {res.type}")
+
+
+run(testOperationResultList)


        


More information about the Mlir-commits mailing list