[Mlir-commits] [mlir] 429b0cf - [mlir][python] Directly implement sequence protocol on Sliceable.

Stella Laurenzo llvmlistbot at llvm.org
Mon Feb 14 09:45:34 PST 2022


Author: Stella Laurenzo
Date: 2022-02-14T09:45:17-08:00
New Revision: 429b0cf1de14471c9d258467bfc9936c3a9d52f7

URL: https://github.com/llvm/llvm-project/commit/429b0cf1de14471c9d258467bfc9936c3a9d52f7
DIFF: https://github.com/llvm/llvm-project/commit/429b0cf1de14471c9d258467bfc9936c3a9d52f7.diff

LOG: [mlir][python] Directly implement sequence protocol on Sliceable.

* While annoying, this is the only way to get C++ exception handling out of the happy path for normal iteration.
* Implements sq_length and sq_item for the sequence protocol (used for iteration, including list() construction).
* Implements mp_subscript for general use (e.g. foo[1] and foo[1:1]).
* For constructing a `list(op.results)`, this reduces the time from ~4-5us to ~1.5us on my machine (give or take measurement overhead) and eliminates C++ exceptions, which is a worthy goal in itself.
  * For comparison, constructing a similar three-integer list takes ~450ns (which may mostly be function call overhead).
  * See issue discussed on the pybind side: https://github.com/pybind/pybind11/issues/2842

Differential Revision: https://reviews.llvm.org/D119691

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 75a72371e9ae1..e791ba8e214c3 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -207,6 +207,8 @@ struct PySinglePartStringAccumulator {
 ///     constructs a new instance of the derived pseudo-container with the
 ///     given slice parameters (to be forwarded to the Sliceable constructor).
 ///
+/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
+///
 /// A derived class may additionally define:
 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
 ///     the python class.
@@ -215,49 +217,53 @@ class Sliceable {
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
+  // Transforms `index` into a legal value to access the underlying sequence.
+  // Negative indices count from the end; returns -1 if out of bounds.
   intptr_t wrapIndex(intptr_t index) {
     if (index < 0)
       index = length + index;
-    if (index < 0 || index >= length) {
-      throw python::SetPyError(PyExc_IndexError,
-                               "attempt to access out of bounds");
-    }
+    if (index < 0 || index >= length)
+      return -1;
     return index;
   }
 
-public:
-  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
-      : startIndex(startIndex), length(length), step(step) {
-    assert(length >= 0 && "expected non-negative slice length");
-  }
-
-  /// Returns the length of the slice.
-  intptr_t dunderLen() const { return length; }
-
   /// Returns the element at the given slice index. Supports negative indices
-  /// by taking elements in inverse order. Throws if the index is out of bounds.
-  ElementTy dunderGetItem(intptr_t index) {
+  /// by taking elements in inverse order. Returns a null object and sets a
+  /// Python IndexError if the index is out of bounds.
+  pybind11::object getItem(intptr_t index) {
     // Negative indices mean we count from the end.
     index = wrapIndex(index);
+    if (index < 0) {
+      PyErr_SetString(PyExc_IndexError, "index out of range");
+      return {};
+    }
 
     // Compute the linear index given the current slice properties.
-    int linearIndex = index * step + startIndex;
+    intptr_t linearIndex = index * step + startIndex;
     assert(linearIndex >= 0 &&
            linearIndex < static_cast<Derived *>(this)->getNumElements() &&
            "linear index out of bounds, the slice is ill-formed");
-    return static_cast<Derived *>(this)->getElement(linearIndex);
+    return pybind11::cast(
+        static_cast<Derived *>(this)->getElement(linearIndex));
   }
 
   /// Returns a new instance of the pseudo-container restricted to the given
-  /// slice.
-  Derived dunderGetItemSlice(pybind11::slice slice) {
+  /// slice. Returns a null object, with the Python error set, on failure.
+  pybind11::object getItemSlice(PyObject *slice) {
     ssize_t start, stop, extraStep, sliceLength;
-    if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
-      throw python::SetPyError(PyExc_IndexError,
-                               "attempt to access out of bounds");
+    if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
+                             &sliceLength) != 0) {
+      // Keep the error set by PySlice_GetIndicesEx (e.g. a zero slice step).
+      return {};
     }
-    return static_cast<Derived *>(this)->slice(startIndex + start * step,
-                                               sliceLength, step * extraStep);
+    return pybind11::cast(static_cast<Derived *>(this)->slice(
+        startIndex + start * step, sliceLength, step * extraStep));
+  }
+
+public:
+  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
+      : startIndex(startIndex), length(length), step(step) {
+    assert(length >= 0 && "expected non-negative slice length");
   }
 
   /// Returns a new vector (mapped to Python list) containing elements from two
@@ -267,10 +273,10 @@ class Sliceable {
     std::vector<ElementTy> elements;
     elements.reserve(length + other.length);
-    for (intptr_t i = 0; i < length; ++i) {
-      elements.push_back(dunderGetItem(i));
-    }
-    for (intptr_t i = 0; i < other.length; ++i) {
-      elements.push_back(other.dunderGetItem(i));
-    }
+    for (intptr_t i = 0; i < length; ++i)
+      elements.push_back(static_cast<Derived *>(this)->getElement(
+          startIndex + i * step));
+    for (intptr_t i = 0; i < other.length; ++i)
+      elements.push_back(static_cast<Derived *>(&other)->getElement(
+          other.startIndex + i * other.step));
     return elements;
   }
@@ -279,11 +285,51 @@ class Sliceable {
   static void bind(pybind11::module &m) {
     auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
                                            pybind11::module_local())
-                     .def("__len__", &Sliceable::dunderLen)
-                     .def("__getitem__", &Sliceable::dunderGetItem)
-                     .def("__getitem__", &Sliceable::dunderGetItemSlice)
                      .def("__add__", &Sliceable::dunderAdd);
     Derived::bindDerived(clazz);
+
+    // Manually implement the sequence protocol via the C API. We do this
+    // because it is approx 4x faster than via pybind11, largely because that
+    // formulation requires a C++ exception to be thrown to detect end of
+    // sequence.
+    // Since we are in a C-context, any C++ exception that happens here
+    // will terminate the program. There is nothing in this implementation
+    // that should throw in a non-terminal way, so we forgo further
+    // exception marshalling.
+    // See: https://github.com/pybind/pybind11/issues/2842
+    auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
+    assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
+           "must be heap type");
+    heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
+      auto self = pybind11::cast<Derived *>(rawSelf);
+      return self->length;
+    };
+    // sq_item is called as part of the sequence protocol for iteration,
+    // list construction, etc.
+    heap_type->as_sequence.sq_item =
+        +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
+      auto self = pybind11::cast<Derived *>(rawSelf);
+      return self->getItem(index).release().ptr();
+    };
+    // mp_subscript is used for both slices and integer lookups.
+    heap_type->as_mapping.mp_subscript =
+        +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
+      auto self = pybind11::cast<Derived *>(rawSelf);
+      Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
+      if (!PyErr_Occurred())
+        return self->getItem(index).release().ptr(); // Integer indexing.
+      if (PyErr_ExceptionMatches(PyExc_IndexError))
+        return nullptr; // Integer too large for Py_ssize_t; keep IndexError.
+      PyErr_Clear();
+
+      // Assume slice-based indexing.
+      if (PySlice_Check(rawSubscript)) {
+        return self->getItemSlice(rawSubscript).release().ptr();
+      }
+
+      PyErr_SetString(PyExc_TypeError, "expected integer or slice");
+      return nullptr;
+    };
   }
 
   /// Hook for derived classes willing to bind more methods.

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index a5b5d3b60b6f2..7608bc56f0b74 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -14,6 +14,14 @@ def run(f):
   return f
 
 
+def expect_index_error(callback):
+  try:
+    callback()
+  except IndexError:
+    return
+  raise RuntimeError("Expected IndexError")
+
+
 # Verify iterator based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
 @run
@@ -418,7 +426,9 @@ def testOperationResultList():
   for t in call.results.types:
     print(f"Result type {t}")
 
-
+  # Out of range
+  expect_index_error(lambda: call.results[3])
+  expect_index_error(lambda: call.results[-4])
 
 
 # CHECK-LABEL: TEST: testOperationResultListSlice
@@ -470,8 +480,6 @@ def testOperationResultListSlice():
       print(f"Result {res.result_number}, type {res.type}")
 
 
-
-
 # CHECK-LABEL: TEST: testOperationAttributes
 @run
 def testOperationAttributes():


        


More information about the Mlir-commits mailing list