[Mlir-commits] [mlir] ee168fb - [mlir][python] Fix issues with block argument slices

Alex Zinenko llvmlistbot at llvm.org
Thu Jul 21 07:41:17 PDT 2022


Author: Alex Zinenko
Date: 2022-07-21T14:41:12Z
New Revision: ee168fb90e4f3321037dcc2c8cb82497a70db92e

URL: https://github.com/llvm/llvm-project/commit/ee168fb90e4f3321037dcc2c8cb82497a70db92e
DIFF: https://github.com/llvm/llvm-project/commit/ee168fb90e4f3321037dcc2c8cb82497a70db92e.diff

LOG: [mlir][python] Fix issues with block argument slices

The type extraction helper function for block argument and op result
list objects was ignoring the slice entirely. So was the slice addition.
Both are caused by a misleading naming convention used to implement
slices via CRTP. Make the convention more explicit and hide the helper
functions so users have a harder time calling them directly.

Closes #56540.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D130271

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAffine.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 0da936e85bc38..fc7133b437aa2 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -385,9 +385,13 @@ class PyAffineMapExprList
                   step),
         affineMap(map) {}
 
-  intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
+
+  intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
 
-  PyAffineExpr getElement(intptr_t pos) {
+  PyAffineExpr getRawElement(intptr_t pos) {
     return PyAffineExpr(affineMap.getContext(),
                         mlirAffineMapGetResult(affineMap, pos));
   }
@@ -397,7 +401,6 @@ class PyAffineMapExprList
     return PyAffineMapExprList(affineMap, startIndex, length, step);
   }
 
-private:
   PyAffineMap affineMap;
 };
 } // namespace
@@ -460,9 +463,13 @@ class PyIntegerSetConstraintList
                   step),
         set(set) {}
 
-  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
+
+  intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
 
-  PyIntegerSetConstraint getElement(intptr_t pos) {
+  PyIntegerSetConstraint getRawElement(intptr_t pos) {
     return PyIntegerSetConstraint(set, pos);
   }
 
@@ -471,7 +478,6 @@ class PyIntegerSetConstraintList
     return PyIntegerSetConstraintList(set, startIndex, length, step);
   }
 
-private:
   PyIntegerSet set;
 };
 } // namespace

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 9738351824c99..fea26ec661a60 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1968,8 +1968,8 @@ template <typename Container>
 static std::vector<PyType> getValueTypes(Container &container,
                                          PyMlirContextRef &context) {
   std::vector<PyType> result;
-  result.reserve(container.getNumElements());
-  for (int i = 0, e = container.getNumElements(); i < e; ++i) {
+  result.reserve(container.size());
+  for (int i = 0, e = container.size(); i < e; ++i) {
     result.push_back(
         PyType(context, mlirValueGetType(container.getElement(i).get())));
   }
@@ -1993,14 +1993,24 @@ class PyBlockArgumentList
                   step),
         operation(std::move(operation)), block(block) {}
 
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
   /// Returns the number of arguments in the list.
-  intptr_t getNumElements() {
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirBlockGetNumArguments(block);
   }
 
-  /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
-  PyBlockArgument getElement(intptr_t pos) {
+  /// Returns `pos`-the element in the list.
+  PyBlockArgument getRawElement(intptr_t pos) {
     MlirValue argument = mlirBlockGetArgument(block, pos);
     return PyBlockArgument(operation, argument);
   }
@@ -2011,13 +2021,6 @@ class PyBlockArgumentList
     return PyBlockArgumentList(operation, block, startIndex, length, step);
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-  }
-
-private:
   PyOperationRef operation;
   MlirBlock block;
 };
@@ -2038,12 +2041,25 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
                   step),
         operation(operation) {}
 
-  intptr_t getNumElements() {
+  void dunderSetItem(intptr_t index, PyValue value) {
+    index = wrapIndex(index);
+    mlirOperationSetOperand(operation->get(), index, value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpOperandList, PyValue>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumOperands(operation->get());
   }
 
-  PyValue getElement(intptr_t pos) {
+  PyValue getRawElement(intptr_t pos) {
     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
     MlirOperation owner;
     if (mlirValueIsAOpResult(operand))
@@ -2061,16 +2077,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
     return PyOpOperandList(operation, startIndex, length, step);
   }
 
-  void dunderSetItem(intptr_t index, PyValue value) {
-    index = wrapIndex(index);
-    mlirOperationSetOperand(operation->get(), index, value.get());
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
-  }
-
-private:
   PyOperationRef operation;
 };
 
@@ -2090,12 +2096,22 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
                   step),
         operation(operation) {}
 
-  intptr_t getNumElements() {
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumResults(operation->get());
   }
 
-  PyOpResult getElement(intptr_t index) {
+  PyOpResult getRawElement(intptr_t index) {
     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
     return PyOpResult(value);
   }
@@ -2104,13 +2120,6 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
     return PyOpResultList(operation, startIndex, length, step);
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("types", [](PyOpResultList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-  }
-
-private:
   PyOperationRef operation;
 };
 

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index e791ba8e214c3..5356cbd54ff48 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -199,15 +199,17 @@ struct PySinglePartStringAccumulator {
 /// A derived class must provide the following:
 ///   - a `static const char *pyClassName ` field containing the name of the
 ///     Python class to bind;
-///   - an instance method `intptr_t getNumElements()` that returns the number
+///   - an instance method `intptr_t getRawNumElements()` that returns the
+///   number
 ///     of elements in the backing container (NOT that of the slice);
-///   - an instance method `ElementTy getElement(intptr_t)` that returns a
-///     single element at the given index.
+///   - an instance method `ElementTy getRawElement(intptr_t)` that returns a
+///     single element at the given linear index (NOT slice index);
 ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
 ///     constructs a new instance of the derived pseudo-container with the
 ///     given slice parameters (to be forwarded to the Sliceable constructor).
 ///
-/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
+/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
+/// throw.
 ///
 /// A derived class may additionally define:
 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
@@ -217,8 +219,8 @@ class Sliceable {
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
-  // Transforms `index` into a legal value to access the underlying sequence.
-  // Returns <0 on failure.
+  /// Transforms `index` into a legal value to access the underlying sequence.
+  /// Returns <0 on failure.
   intptr_t wrapIndex(intptr_t index) {
     if (index < 0)
       index = length + index;
@@ -227,6 +229,15 @@ class Sliceable {
     return index;
   }
 
+  /// Computes the linear index given the current slice properties.
+  intptr_t linearizeIndex(intptr_t index) {
+    intptr_t linearIndex = index * step + startIndex;
+    assert(linearIndex >= 0 &&
+           linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
+           "linear index out of bounds, the slice is ill-formed");
+    return linearIndex;
+  }
+
   /// Returns the element at the given slice index. Supports negative indices
   /// by taking elements in inverse order. Returns a nullptr object if out
   /// of bounds.
@@ -238,13 +249,8 @@ class Sliceable {
       return {};
     }
 
-    // Compute the linear index given the current slice properties.
-    int linearIndex = index * step + startIndex;
-    assert(linearIndex >= 0 &&
-           linearIndex < static_cast<Derived *>(this)->getNumElements() &&
-           "linear index out of bounds, the slice is ill-formed");
     return pybind11::cast(
-        static_cast<Derived *>(this)->getElement(linearIndex));
+        static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
   }
 
   /// Returns a new instance of the pseudo-container restricted to the given
@@ -266,6 +272,21 @@ class Sliceable {
     assert(length >= 0 && "expected non-negative slice length");
   }
 
+  /// Returns the `index`-th element in the slice, supports negative indices.
+  /// Throws if the index is out of bounds.
+  ElementTy getElement(intptr_t index) {
+    // Negative indices mean we count from the end.
+    index = wrapIndex(index);
+    if (index < 0) {
+      throw pybind11::index_error("index out of range");
+    }
+
+    return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
+  }
+
+  /// Returns the size of slice.
+  intptr_t size() { return length; }
+
   /// Returns a new vector (mapped to Python list) containing elements from two
   /// slices. The new vector is necessary because slices may not be contiguous
   /// or even come from the same original sequence.
@@ -276,7 +297,7 @@ class Sliceable {
       elements.push_back(static_cast<Derived *>(this)->getElement(i));
     }
     for (intptr_t i = 0; i < other.length; ++i) {
-      elements.push_back(static_cast<Derived *>(this)->getElement(i));
+      elements.push_back(static_cast<Derived *>(&other)->getElement(i));
     }
     return elements;
   }

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index b7b47f8bf07c8..2d70b88a17b54 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -185,6 +185,19 @@ def testBlockArgumentList():
     for t in entry_block.arguments.types:
       print("Type: ", t)
 
+    # Check that slicing and type access compose.
+    # CHECK: Sliced type: i16
+    # CHECK: Sliced type: i24
+    for t in entry_block.arguments[1:].types:
+      print("Sliced type: ", t)
+
+    # Check that slice addition works as expected.
+    # CHECK: Argument 2, type i24
+    # CHECK: Argument 0, type i8
+    restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
+    for arg in restructured:
+      print(f"Argument {arg.arg_number}, type {arg.type}")
+
 
 # CHECK-LABEL: TEST: testOperationOperands
 @run


        


More information about the Mlir-commits mailing list