[Mlir-commits] [mlir] afeda4b - [mlir][python] provide access to function argument/result attributes

Alex Zinenko llvmlistbot at llvm.org
Thu Sep 30 00:38:22 PDT 2021


Author: Alex Zinenko
Date: 2021-09-30T09:38:13+02:00
New Revision: afeda4b9ed881fb3f1b4340b8afa9963656013dc

URL: https://github.com/llvm/llvm-project/commit/afeda4b9ed881fb3f1b4340b8afa9963656013dc
DIFF: https://github.com/llvm/llvm-project/commit/afeda4b9ed881fb3f1b4340b8afa9963656013dc.diff

LOG: [mlir][python] provide access to function argument/result attributes

Without this change, these attributes can only be accessed through the generic
operation attribute dictionary, and only if the caller knows the special
operation attribute names used for this purpose. Add some Python wrapping to
support this use case.

Also provide access to function arguments, usable inside the function body,
along with a couple of quality-of-life improvements for working with block
arguments (function arguments are the arguments of the function's entry block).

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110758

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/python/mlir/dialects/_builtin_ops_ext.py
    mlir/test/python/dialects/builtin.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 473c94c900c01..0434ac37c6881 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1594,32 +1594,35 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
 /// elements, random access is cheap. The argument list is associated with the
 /// operation that contains the block (detached blocks are not allowed in
 /// Python bindings) and extends its lifetime.
-class PyBlockArgumentList {
+class PyBlockArgumentList
+    : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
 public:
-  PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
-      : operation(std::move(operation)), block(block) {}
+  static constexpr const char *pyClassName = "BlockArgumentList";
 
-  /// Returns the length of the block argument list.
-  intptr_t dunderLen() {
+  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumArguments(block) : length,
+                  step),
+        operation(std::move(operation)), block(block) {}
+
+  /// Returns the number of arguments in the list.
+  intptr_t getNumElements() {
     operation->checkValid();
     return mlirBlockGetNumArguments(block);
   }
 
-  /// Returns `index`-th element of the block argument list.
-  PyBlockArgument dunderGetItem(intptr_t index) {
-    if (index < 0 || index >= dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds region");
-    }
-    PyValue value(operation, mlirBlockGetArgument(block, index));
-    return PyBlockArgument(value);
+  /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
+  PyBlockArgument getElement(intptr_t pos) {
+    MlirValue argument = mlirBlockGetArgument(block, pos);
+    return PyBlockArgument(operation, argument);
   }
 
-  /// Defines a Python class in the bindings.
-  static void bind(py::module &m) {
-    py::class_<PyBlockArgumentList>(m, "BlockArgumentList", py::module_local())
-        .def("__len__", &PyBlockArgumentList::dunderLen)
-        .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
+  /// Returns a sublist of this list.
+  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockArgumentList(operation, block, startIndex, length, step);
   }
 
 private:

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 9fecc7ce37af0..2fdcf695b09bd 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -260,13 +260,29 @@ class Sliceable {
                                                sliceLength, step * extraStep);
   }
 
+  /// Returns a new vector (mapped to Python list) containing elements from two
+  /// slices. The new vector is necessary because slices may not be contiguous
+  /// or even come from the same original sequence.
+  std::vector<ElementTy> dunderAdd(Derived &other) {
+    std::vector<ElementTy> elements;
+    elements.reserve(length + other.length);
+    for (intptr_t i = 0; i < length; ++i) {
+      elements.push_back(dunderGetItem(i));
+    }
+    for (intptr_t i = 0; i < other.length; ++i) {
+      elements.push_back(other.dunderGetItem(i));
+    }
+    return elements;
+  }
+
   /// Binds the indexing and length methods in the Python class.
   static void bind(pybind11::module &m) {
     auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
                                            pybind11::module_local())
                      .def("__len__", &Sliceable::dunderLen)
                      .def("__getitem__", &Sliceable::dunderGetItem)
-                     .def("__getitem__", &Sliceable::dunderGetItemSlice);
+                     .def("__getitem__", &Sliceable::dunderGetItemSlice)
+                     .def("__add__", &Sliceable::dunderAdd);
     Derived::bindDerived(clazz);
   }
 

diff  --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
index 99783d8333753..d464819f2dc8a 100644
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -11,6 +11,8 @@
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
 
 class ModuleOp:
   """Specialization for the module op class."""
@@ -100,6 +102,26 @@ def add_entry_block(self):
     self.body.blocks.append(*self.type.inputs)
     return self.body.blocks[0]
 
+  @property
+  def arg_attrs(self):
+    return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
+
+  @arg_attrs.setter
+  def arg_attrs(self, attribute: ArrayAttr):
+    self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+
+  @property
+  def arguments(self):
+    return self.entry_block.arguments
+
+  @property
+  def result_attrs(self):
+    return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+  @result_attrs.setter
+  def result_attrs(self, attribute: ArrayAttr):
+    self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+
   @classmethod
   def from_py_func(FuncOp,
                    *inputs: Type,

diff  --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 1f4847dce892c..b87eabb72b94a 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -161,3 +161,41 @@ def testBuildFuncOp():
   # CHECK:   return %arg0 : tensor<2x3x4xf32>
   # CHECK:  }
   print(m)
+
+
+# CHECK-LABEL: TEST: testFuncArgumentAccess
+@run
+def testFuncArgumentAccess():
+  with Context(), Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    f64 = F64Type.get()
+    with InsertionPoint(module.body):
+      func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
+      with InsertionPoint(func.add_entry_block()):
+        std.ReturnOp(func.arguments)
+      func.arg_attrs = ArrayAttr.get([
+          DictAttr.get({
+              "foo": StringAttr.get("bar"),
+              "baz": UnitAttr.get()
+          }),
+          DictAttr.get({"qux": ArrayAttr.get([])})
+      ])
+      func.result_attrs = ArrayAttr.get([
+          DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}),
+          DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
+      ])
+
+  # CHECK: [{baz, foo = "bar"}, {qux = []}]
+  print(func.arg_attrs)
+
+  # CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}]
+  print(func.result_attrs)
+
+  # CHECK: func @some_func(
+  # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
+  # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
+  # CHECK: f64 {res1 = 4.200000e+01 : f32},
+  # CHECK: f64 {res2 = 2.560000e+02 : f64})
+  # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+  print(module)

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 5f510fb6b2cbf..d04f52f0c0fe9 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -134,6 +134,17 @@ def testBlockArgumentList():
     for arg in entry_block.arguments:
       print(f"Argument {arg.arg_number}, type {arg.type}")
 
+    # Check that slicing works for block argument lists.
+    # CHECK: Argument 1, type i16
+    # CHECK: Argument 2, type i24
+    for arg in entry_block.arguments[1:]:
+      print(f"Argument {arg.arg_number}, type {arg.type}")
+
+    # Check that we can concatenate slices of argument lists.
+    # CHECK: Length: 4
+    print("Length: ",
+          len(entry_block.arguments[:2] + entry_block.arguments[1:]))
+
 
 run(testBlockArgumentList)
 
@@ -598,22 +609,24 @@ def testCreateWithInvalidAttributes():
   ctx = Context()
   with Location.unknown(ctx):
     try:
-      Operation.create("builtin.module", attributes={None:StringAttr.get("name")})
+      Operation.create(
+          "builtin.module", attributes={None: StringAttr.get("name")})
     except Exception as e:
       # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
       print(e)
     try:
-      Operation.create("builtin.module", attributes={42:StringAttr.get("name")})
+      Operation.create(
+          "builtin.module", attributes={42: StringAttr.get("name")})
     except Exception as e:
       # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
       print(e)
     try:
-      Operation.create("builtin.module", attributes={"some_key":ctx})
+      Operation.create("builtin.module", attributes={"some_key": ctx})
     except Exception as e:
       # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
       print(e)
     try:
-      Operation.create("builtin.module", attributes={"some_key":None})
+      Operation.create("builtin.module", attributes={"some_key": None})
     except Exception as e:
       # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
       print(e)


        


More information about the Mlir-commits mailing list