[llvm-branch-commits] [mlir] 547e3ee - [mlir] Expose MemRef layout in Python bindings

Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jan 11 11:02:19 PST 2021


Author: Alex Zinenko
Date: 2021-01-11T19:57:16+01:00
New Revision: 547e3eef14a8e75a867dfcc6b45cd1f0547d4e07

URL: https://github.com/llvm/llvm-project/commit/547e3eef14a8e75a867dfcc6b45cd1f0547d4e07
DIFF: https://github.com/llvm/llvm-project/commit/547e3eef14a8e75a867dfcc6b45cd1f0547d4e07.diff

LOG: [mlir] Expose MemRef layout in Python bindings

This wasn't possible before because there was no support for affine expressions
as maps. Now that this support is available, provide a mechanism for
constructing memref types with a layout and for inspecting that layout.

Replace the `get_contiguous_memref` method on MemRefType in Python with `get`,
which no longer requires an explicit memory space and accepts an optional list
of layout maps. Remove `num_affine_maps`: it is too low-level, and using the
length of the now-available pseudo-list of layout maps is more Pythonic.

Depends On D94297

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94302

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 77898be41565..9712d58ad87a 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -225,7 +225,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
  * same context as element type. The type is owned by the context. */
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
     MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
-    MlirAttribute const *affineMaps, unsigned memorySpace);
+    MlirAffineMap const *affineMaps, unsigned memorySpace);
+
+/** Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on
+ * illegal arguments, emitting appropriate diagnostics. */
+MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
+    MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
+    MlirAffineMap const *affineMaps, unsigned memorySpace, MlirLocation loc);
 
 /** Creates a MemRef type with the given rank, shape, memory space and element
  * type in the same context as the element type. The type has no affine maps,

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 81f84b8152f4..218099bedc6f 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2535,6 +2535,8 @@ class PyUnrankedTensorType
   }
 };
 
+class PyMemRefLayoutMapList;
+
 /// Ranked MemRef Type subclass - MemRefType.
 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 public:
@@ -2542,16 +2544,22 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
   static constexpr const char *pyClassName = "MemRefType";
   using PyConcreteType::PyConcreteType;
 
+  PyMemRefLayoutMapList getLayout();
+
   static void bindDerived(ClassTy &c) {
-    // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
-    // once the affine map binding is completed.
     c.def_static(
-         "get_contiguous_memref",
-         // TODO: Make the location optional and create a default location.
+         "get",
          [](PyType &elementType, std::vector<int64_t> shape,
-            unsigned memorySpace, DefaultingPyLocation loc) {
-           MlirType t = mlirMemRefTypeContiguousGetChecked(
-               elementType, shape.size(), shape.data(), memorySpace, loc);
+            std::vector<PyAffineMap> layout, unsigned memorySpace,
+            DefaultingPyLocation loc) {
+           SmallVector<MlirAffineMap> maps;
+           maps.reserve(layout.size());
+           for (PyAffineMap &map : layout)
+             maps.push_back(map);
+
+           MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(),
+                                                 shape.data(), maps.size(),
+                                                 maps.data(), memorySpace, loc);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2565,15 +2573,11 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
            }
            return PyMemRefType(elementType.getContext(), t);
          },
-         py::arg("element_type"), py::arg("shape"), py::arg("memory_space"),
+         py::arg("element_type"), py::arg("shape"),
+         py::arg("layout") = py::list(), py::arg("memory_space") = 0,
          py::arg("loc") = py::none(), "Create a memref type")
-        .def_property_readonly(
-            "num_affine_maps",
-            [](PyMemRefType &self) -> intptr_t {
-              return mlirMemRefTypeGetNumAffineMaps(self);
-            },
-            "Returns the number of affine layout maps in the given MemRef "
-            "type.")
+        .def_property_readonly("layout", &PyMemRefType::getLayout,
+                               "The list of layout maps of the MemRef type.")
         .def_property_readonly(
             "memory_space",
             [](PyMemRefType &self) -> unsigned {
@@ -2583,6 +2587,41 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
   }
 };
 
+/// A list of affine layout maps in a memref type. Internally, these are stored
+/// as consecutive elements, so random access is cheap. Both the type and the
+/// maps are owned by the context, so lifetime extension is not needed.
+class PyMemRefLayoutMapList
+    : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
+public:
+  static constexpr const char *pyClassName = "MemRefLayoutMapList";
+
+  PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
+                        intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
+                  step),
+        memref(type) {}
+
+  intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
+
+  PyAffineMap getElement(intptr_t index) {
+    return PyAffineMap(memref.getContext(),
+                       mlirMemRefTypeGetAffineMap(memref, index));
+  }
+
+  PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
+                              intptr_t step) {
+    return PyMemRefLayoutMapList(memref, startIndex, length, step);
+  }
+
+private:
+  PyMemRefType memref; // Memref type whose layout maps this list exposes.
+};
+
+PyMemRefLayoutMapList PyMemRefType::getLayout() {
+  return PyMemRefLayoutMapList(*this); // Default args span all layout maps.
+}
+
 /// Unranked MemRef Type subclass - UnrankedMemRefType.
 class PyUnrankedMemRefType
     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
@@ -3631,6 +3670,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyRankedTensorType::bind(m);
   PyUnrankedTensorType::bind(m);
   PyMemRefType::bind(m);
+  PyMemRefLayoutMapList::bind(m);
   PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
   PyFunctionType::bind(m);

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index f0c2901f5e10..2de2fa1afde2 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -231,6 +231,17 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
                       unwrap(elementType), maps, memorySpace));
 }
 
+MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank,
+                                  const int64_t *shape, intptr_t numMaps,
+                                  MlirAffineMap const *affineMaps,
+                                  unsigned memorySpace, MlirLocation loc) {
+  SmallVector<AffineMap, 1> maps; // Storage for the unwrapped `affineMaps`.
+  (void)unwrapList(numMaps, affineMaps, maps);
+  return wrap(MemRefType::getChecked( // Null type on illegal arguments.
+      unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+      unwrap(elementType), maps, memorySpace));
+}
+
 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                                      const int64_t *shape,
                                      unsigned memorySpace) {

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index a65095c4289e..64b684ee99e9 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -326,17 +326,27 @@ def testMemRefType():
     f32 = F32Type.get()
     shape = [2, 3]
     loc = Location.unknown()
-    memref = MemRefType.get_contiguous_memref(f32, shape, 2)
+    memref = MemRefType.get(f32, shape, memory_space=2)
     # CHECK: memref type: memref<2x3xf32, 2>
     print("memref type:", memref)
     # CHECK: number of affine layout maps: 0
-    print("number of affine layout maps:", memref.num_affine_maps)
+    print("number of affine layout maps:", len(memref.layout))
     # CHECK: memory space: 2
     print("memory space:", memref.memory_space)
 
+    layout = AffineMap.get_permutation([1, 0])
+    memref_layout = MemRefType.get(f32, shape, [layout])
+    # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+    print("memref type:", memref_layout)
+    assert len(memref_layout.layout) == 1
+    # CHECK: memref layout: (d0, d1) -> (d1, d0)
+    print("memref layout:", memref_layout.layout[0])
+    # CHECK: memory space: 0
+    print("memory space:", memref_layout.memory_space)
+
     none = NoneType.get()
     try:
-      memref_invalid = MemRefType.get_contiguous_memref(none, shape, 2)
+      memref_invalid = MemRefType.get(none, shape)
     except ValueError as e:
       # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
       # CHECK: or complex type.


        


More information about the llvm-branch-commits mailing list