[Mlir-commits] [mlir] 6c7e6b2 - [mlir] Support slicing for operands and results in Python bindings

Alex Zinenko llvmlistbot at llvm.org
Tue Nov 10 01:46:30 PST 2020


Author: Alex Zinenko
Date: 2020-11-10T10:46:21+01:00
New Revision: 6c7e6b2c9abddf8b657997053d140ca4554cafbb

URL: https://github.com/llvm/llvm-project/commit/6c7e6b2c9abddf8b657997053d140ca4554cafbb
DIFF: https://github.com/llvm/llvm-project/commit/6c7e6b2c9abddf8b657997053d140ca4554cafbb.diff

LOG: [mlir] Support slicing for operands and results in Python bindings

Slicing, that is element access with `[begin:end:step]` structure, is
a common Python idiom for sequence-like containers. It is also necessary
to support custom accessors for operations with variadic operands and
results (an operation can return a slice of its operands that corresponds
to the given variadic group).

Add generic utility to support slicing in Python bindings and use it
for operation operands and results.

Depends On D90923

Reviewed By: stellaraccident, mehdi_amini

Differential Revision: https://reviews.llvm.org/D90936

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index de557c475b55..78924d9ae2c5 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -492,6 +492,9 @@ mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
 /// Returns whether the value is null.
 static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
 
+/// Returns 1 if two values are equal, 0 otherwise.
+MLIR_CAPI_EXPORTED int mlirValueEqual(MlirValue value1, MlirValue value2);
+
 /// Returns 1 if the value is a block argument, 0 otherwise.
 MLIR_CAPI_EXPORTED int mlirValueIsABlockArgument(MlirValue value);
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 2e02d775de3d..24b3da2b821f 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1213,34 +1213,41 @@ class PyBlockArgumentList {
   MlirBlock block;
 };
 
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The result list is associated with the
-/// operation whose results these are, and extends the lifetime of this
+/// A list of operation operands. Internally, these are stored as consecutive
+/// elements, random access is cheap. The operand list is associated with the
+/// operation whose operands these are, and extends the lifetime of this
 /// operation.
-class PyOpOperandList {
+class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
 public:
-  PyOpOperandList(PyOperationRef operation) : operation(operation) {}
+  static constexpr const char *pyClassName = "OpOperandList";
 
-  /// Returns the length of the result list.
-  intptr_t dunderLen() {
+  /// Creates a view over the operands of `operation`. A `length` of -1 is
+  /// replaced by the total number of operands of the operation.
+  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+                  intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumOperands(operation->get())
+                               : length,
+                  step),
+        operation(operation) {}
+
+  /// Returns the total number of operands (not the length of this slice).
+  intptr_t getNumElements() {
     operation->checkValid();
     return mlirOperationGetNumOperands(operation->get());
   }
 
-  /// Returns `index`-th element in the result list.
-  PyValue dunderGetItem(intptr_t index) {
-    if (index < 0 || index >= dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds region");
-    }
-    return PyValue(operation, mlirOperationGetOperand(operation->get(), index));
+  /// Returns the operand at position `pos` of the underlying operation.
+  PyValue getElement(intptr_t pos) {
+    // Sliceable::dunderGetItem only calls getNumElements() (and thus
+    // checkValid()) inside an assert, so check validity explicitly here.
+    operation->checkValid();
+    return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
   }
 
-  /// Defines a Python class in the bindings.
-  static void bind(py::module &m) {
-    py::class_<PyOpOperandList>(m, "OpOperandList")
-        .def("__len__", &PyOpOperandList::dunderLen)
-        .def("__getitem__", &PyOpOperandList::dunderGetItem);
+  /// Returns a new view over the same operation with the given slice bounds.
+  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpOperandList(operation, startIndex, length, step);
   }
 
 private:
@@ -1251,31 +1250,38 @@ class PyOpOperandList {
 /// elements, random access is cheap. The result list is associated with the
 /// operation whose results these are, and extends the lifetime of this
 /// operation.
-class PyOpResultList {
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
 public:
-  PyOpResultList(PyOperationRef operation) : operation(operation) {}
+  static constexpr const char *pyClassName = "OpResultList";
 
-  /// Returns the length of the result list.
-  intptr_t dunderLen() {
+  /// Creates a view over the results of `operation`. A `length` of -1 is
+  /// replaced by the total number of results of the operation.
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(operation) {}
+
+  /// Returns the total number of results (not the length of this slice).
+  intptr_t getNumElements() {
     operation->checkValid();
     return mlirOperationGetNumResults(operation->get());
   }
 
-  /// Returns `index`-th element in the result list.
-  PyOpResult dunderGetItem(intptr_t index) {
-    if (index < 0 || index >= dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds region");
-    }
+  /// Returns the result at position `index` of the underlying operation.
+  PyOpResult getElement(intptr_t index) {
+    // Sliceable::dunderGetItem only calls getNumElements() (and thus
+    // checkValid()) inside an assert, so check validity explicitly here.
+    operation->checkValid();
     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
     return PyOpResult(value);
   }
 
-  /// Defines a Python class in the bindings.
-  static void bind(py::module &m) {
-    py::class_<PyOpResultList>(m, "OpResultList")
-        .def("__len__", &PyOpResultList::dunderLen)
-        .def("__getitem__", &PyOpResultList::dunderGetItem);
+  /// Returns a new view over the same operation with the given slice bounds.
+  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpResultList(operation, startIndex, length, step);
   }
 
 private:
@@ -2932,6 +2930,20 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def(
           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
           kDumpDocstring)
+      // Two Values compare equal when their C API handles hold the same
+      // pointer.
+      .def("__eq__",
+           [](PyValue &self, PyValue &other) {
+             return self.get().ptr == other.get().ptr;
+           })
+      // Fallback overload: a Value never compares equal to a non-Value object.
+      .def("__eq__", [](PyValue &self, py::object other) { return false; })
+      // Hash the same pointer that __eq__ compares, so that equal Values have
+      // equal hashes and can be used in sets and as dict keys.
+      .def("__hash__",
+           [](PyValue &self) {
+             return reinterpret_cast<intptr_t>(self.get().ptr);
+           })
       .def(
           "__str__",
           [](PyValue &self) {

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 32697e59adba..3b24d6d962a0 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -185,6 +185,93 @@ struct PySinglePartStringAccumulator {
   bool invoked = false;
 };
 
+/// A CRTP base class for pseudo-containers willing to support Python-style
+/// slicing access on top of indexed access. Calling ::bind on this class
+/// will define `__len__` as well as `__getitem__` with integer and slice
+/// arguments.
+///
+/// This is intended for pseudo-containers that can refer to arbitrary slices of
+/// underlying storage indexed by a single integer. Indexing those with an
+/// integer produces an instance of ElementTy. Indexing those with a slice
+/// produces a new instance of Derived, which can be sliced further.
+///
+/// A derived class must provide the following:
+///   - a `static const char *pyClassName` field containing the name of the
+///     Python class to bind;
+///   - an instance method `intptr_t getNumElements()` that returns the number
+///     of elements in the backing container (NOT that of the slice);
+///   - an instance method `ElementTy getElement(intptr_t)` that returns a
+///     single element at the given index;
+///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
+///     constructs a new instance of the derived pseudo-container with the
+///     given slice parameters (to be forwarded to the Sliceable constructor).
+///
+/// A derived class may additionally define:
+///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
+///     to the Python class.
+template <typename Derived, typename ElementTy>
+class Sliceable {
+protected:
+  using ClassTy = pybind11::class_<Derived>;
+
+public:
+  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
+      : startIndex(startIndex), length(length), step(step) {
+    assert(length >= 0 && "expected non-negative slice length");
+  }
+
+  /// Returns the length of the slice.
+  intptr_t dunderLen() const { return length; }
+
+  /// Returns the element at the given slice index. Negative indices count from
+  /// the end of the slice. Throws IndexError if the index is out of bounds.
+  ElementTy dunderGetItem(intptr_t index) {
+    // Negative indices mean we count from the end.
+    if (index < 0)
+      index = length + index;
+    if (index < 0 || index >= length) {
+      throw python::SetPyError(PyExc_IndexError,
+                               "attempt to access out of bounds");
+    }
+
+    // Map the slice index to a position in the full backing container.
+    intptr_t linearIndex = index * step + startIndex;
+    assert(linearIndex >= 0 &&
+           linearIndex < static_cast<Derived *>(this)->getNumElements() &&
+           "linear index out of bounds, the slice is ill-formed");
+    return static_cast<Derived *>(this)->getElement(linearIndex);
+  }
+
+  /// Returns a new instance of the pseudo-container restricted to the given
+  /// slice.
+  Derived dunderGetItemSlice(pybind11::slice slice) {
+    ssize_t start, stop, extraStep, sliceLength;
+    if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
+      // compute() only fails with a Python error set; propagate it as is.
+      throw pybind11::error_already_set();
+    }
+    return static_cast<Derived *>(this)->slice(startIndex + start * step,
+                                               sliceLength, step * extraStep);
+  }
+
+  /// Binds the indexing and length methods in the Python class.
+  static void bind(pybind11::module &m) {
+    auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName)
+                     .def("__len__", &Sliceable::dunderLen)
+                     .def("__getitem__", &Sliceable::dunderGetItem)
+                     .def("__getitem__", &Sliceable::dunderGetItemSlice);
+    Derived::bindDerived(clazz);
+  }
+
+  /// Hook for derived classes willing to bind more methods.
+  static void bindDerived(ClassTy &) {}
+
+private:
+  intptr_t startIndex;
+  intptr_t length;
+  intptr_t step;
+};
+
 } // namespace mlir
 
 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 9ad1a4c9a3a6..d8ff10f04b7a 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -479,6 +479,10 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
 // Value API.
 //===----------------------------------------------------------------------===//
 
+int mlirValueEqual(MlirValue value1, MlirValue value2) {
+  return unwrap(value1) == unwrap(value2); // Compare the unwrapped C++ values.
+}
+
 int mlirValueIsABlockArgument(MlirValue value) {
   return unwrap(value).isa<BlockArgument>();
 }

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 54bc428ce8ae..8827e5dafe2c 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -155,6 +155,83 @@ def testOperationOperands():
 run(testOperationOperands)
 
 
+# CHECK-LABEL: TEST: testOperationOperandsSlice
+def testOperationOperandsSlice():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      func @f1() {
+        %0 = "test.producer0"() : () -> i64
+        %1 = "test.producer1"() : () -> i64
+        %2 = "test.producer2"() : () -> i64
+        %3 = "test.producer3"() : () -> i64
+        %4 = "test.producer4"() : () -> i64
+        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
+        return
+      }""")
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    consumer = entry_block.operations[5]
+    assert len(consumer.operands) == 5
+    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
+      assert left == right
+
+    # Negative indices count from the end of the list.
+    assert consumer.operands[-1] == consumer.operands[4]
+    assert consumer.operands[-5] == consumer.operands[0]
+
+    # Out-of-bounds indices, positive or negative, raise IndexError.
+    for bad_index in (5, -6):
+      try:
+        consumer.operands[bad_index]
+      except IndexError:
+        pass
+      else:
+        assert False, f"expected IndexError for index {bad_index}"
+
+    # CHECK: test.producer0
+    # CHECK: test.producer1
+    # CHECK: test.producer2
+    # CHECK: test.producer3
+    # CHECK: test.producer4
+    full_slice = consumer.operands[:]
+    assert len(full_slice) == 5
+    for operand in full_slice:
+      print(operand)
+
+    # CHECK: test.producer0
+    # CHECK: test.producer1
+    first_two = consumer.operands[0:2]
+    assert len(first_two) == 2
+    for operand in first_two:
+      print(operand)
+
+    # CHECK: test.producer3
+    # CHECK: test.producer4
+    last_two = consumer.operands[3:]
+    assert len(last_two) == 2
+    for operand in last_two:
+      print(operand)
+
+    # CHECK: test.producer0
+    # CHECK: test.producer2
+    # CHECK: test.producer4
+    even = consumer.operands[::2]
+    assert len(even) == 3
+    for operand in even:
+      print(operand)
+
+    # CHECK: test.producer2
+    # Slicing a slice composes: this selects every fourth operand from index 2.
+    fourth = consumer.operands[::2][1::2]
+    assert len(fourth) == 1
+    for operand in fourth:
+      print(operand)
+
+
+run(testOperationOperandsSlice)
+
+
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
   ctx = Context()
@@ -277,6 +335,66 @@ def testOperationResultList():
 run(testOperationResultList)
 
 
+# CHECK-LABEL: TEST: testOperationResultListSlice
+def testOperationResultListSlice():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      func @f1() {
+        "some.op"() : () -> (i1, i2, i3, i4, i5)
+        return
+      }
+    """)
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    producer = entry_block.operations[0]
+
+    assert len(producer.results) == 5
+    for left, right in zip(producer.results, producer.results[::-1][::-1]):
+      assert left == right
+      assert left.result_number == right.result_number
+
+    # Negative indices count from the end of the list.
+    assert producer.results[-1].result_number == 4
+    assert producer.results[-5].result_number == 0
+
+    # CHECK: Result 0, type i1
+    # CHECK: Result 1, type i2
+    # CHECK: Result 2, type i3
+    # CHECK: Result 3, type i4
+    # CHECK: Result 4, type i5
+    full_slice = producer.results[:]
+    assert len(full_slice) == 5
+    for res in full_slice:
+      print(f"Result {res.result_number}, type {res.type}")
+
+    # CHECK: Result 1, type i2
+    # CHECK: Result 2, type i3
+    # CHECK: Result 3, type i4
+    middle = producer.results[1:4]
+    assert len(middle) == 3
+    for res in middle:
+      print(f"Result {res.result_number}, type {res.type}")
+
+    # CHECK: Result 1, type i2
+    # CHECK: Result 3, type i4
+    odd = producer.results[1::2]
+    assert len(odd) == 2
+    for res in odd:
+      print(f"Result {res.result_number}, type {res.type}")
+
+    # CHECK: Result 3, type i4
+    # CHECK: Result 1, type i2
+    # A negative step walks the list backwards.
+    inverted_middle = producer.results[-2:0:-2]
+    assert len(inverted_middle) == 2
+    for res in inverted_middle:
+      print(f"Result {res.result_number}, type {res.type}")
+
+
+run(testOperationResultListSlice)
+
+
 # CHECK-LABEL: TEST: testOperationAttributes
 def testOperationAttributes():
   ctx = Context()


        


More information about the Mlir-commits mailing list