[Mlir-commits] [mlir] e79bd0b - [mlir] More Python bindings for AffineMap

Alex Zinenko llvmlistbot at llvm.org
Mon Jan 11 10:57:28 PST 2021


Author: Alex Zinenko
Date: 2021-01-11T19:57:15+01:00
New Revision: e79bd0b4f25e68130a2ac273d6508ea322028b61

URL: https://github.com/llvm/llvm-project/commit/e79bd0b4f25e68130a2ac273d6508ea322028b61
DIFF: https://github.com/llvm/llvm-project/commit/e79bd0b4f25e68130a2ac273d6508ea322028b61.diff

LOG: [mlir] More Python bindings for AffineMap

Now that the bindings for AffineExpr have been added, add more bindings for
constructing and inspecting AffineMap that consists of AffineExprs.

Depends On D94225

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94297

Added: 
    

Modified: 
    mlir/include/mlir-c/AffineExpr.h
    mlir/include/mlir-c/AffineMap.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/AffineMap.cpp
    mlir/test/Bindings/Python/ir_affine_map.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index 2eb8ae03e03d..ec445682c011 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -10,7 +10,6 @@
 #ifndef MLIR_C_AFFINEEXPR_H
 #define MLIR_C_AFFINEEXPR_H
 
-#include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
 
 #ifdef __cplusplus

diff  --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
index c52fe6826251..bf0c7c7b5381 100644
--- a/mlir/include/mlir-c/AffineMap.h
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -10,6 +10,7 @@
 #ifndef MLIR_C_AFFINEMAP_H
 #define MLIR_C_AFFINEMAP_H
 
+#include "mlir-c/AffineExpr.h"
 #include "mlir-c/IR.h"
 
 #ifdef __cplusplus
@@ -67,9 +68,18 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapEmptyGet(MlirContext ctx);
 
 /** Creates a zero result affine map of the given dimensions and symbols in the
  * context. The affine map is owned by the context. */
+MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapZeroResultGet(
+    MlirContext ctx, intptr_t dimCount, intptr_t symbolCount);
+
+/** Creates an affine map with results defined by the given list of affine
+ * expressions. The resulting map also has the requested number of input
+ * dimensions and symbols, regardless of whether they are used in the results.
+ * The affine map is owned by the context. */
 MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapGet(MlirContext ctx,
                                                   intptr_t dimCount,
-                                                  intptr_t symbolCount);
+                                                  intptr_t symbolCount,
+                                                  intptr_t nAffineExprs,
+                                                  MlirAffineExpr *affineExprs);
 
 /** Creates a single constant result affine map in the context. The affine map
  * is owned by the context. */
@@ -124,6 +134,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirAffineMapGetNumSymbols(MlirAffineMap affineMap);
 /// Returns the number of results of the given affine map.
 MLIR_CAPI_EXPORTED intptr_t mlirAffineMapGetNumResults(MlirAffineMap affineMap);
 
+/// Returns the result at the given position.
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirAffineMapGetResult(MlirAffineMap affineMap, intptr_t pos);
+
 /** Returns the number of inputs (dimensions + symbols) of the given affine
  * map. */
 MLIR_CAPI_EXPORTED intptr_t mlirAffineMapGetNumInputs(MlirAffineMap affineMap);

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 2d18a7a488e7..81f84b8152f4 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -11,6 +11,7 @@
 #include "Globals.h"
 #include "PybindUtils.h"
 
+#include "mlir-c/AffineMap.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
@@ -2943,9 +2944,43 @@ PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
 }
 
 //------------------------------------------------------------------------------
-// PyAffineMap.
+// PyAffineMap and utilities.
 //------------------------------------------------------------------------------
 
+namespace {
+/// A list of expressions contained in an affine map. Internally these are
+/// stored as a consecutive array leading to inexpensive random access. Both
+/// the map and the expressions are owned by the context, so we need not
+/// bother with lifetime extension.
+class PyAffineMapExprList
+    : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
+public:
+  static constexpr const char *pyClassName = "AffineExprList";
+
+  PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
+                      intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirAffineMapGetNumResults(map) : length,
+                  step),
+        affineMap(map) {}
+
+  intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
+
+  PyAffineExpr getElement(intptr_t pos) {
+    return PyAffineExpr(affineMap.getContext(),
+                        mlirAffineMapGetResult(affineMap, pos));
+  }
+
+  PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyAffineMapExprList(affineMap, startIndex, length, step);
+  }
+
+private:
+  PyAffineMap affineMap;
+};
+} // end namespace
+
 bool PyAffineMap::operator==(const PyAffineMap &other) {
   return mlirAffineMapEqual(affineMap, other.affineMap);
 }
@@ -3741,6 +3776,72 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyAffineMap::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
+      .def("__eq__",
+           [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
+      .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineMap(");
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineMap &self) { return self.getContext().getObject(); },
+          "Context that owns the Affine Map")
+      .def(
+          "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
+          kDumpDocstring)
+      .def_static(
+          "get",
+          [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
+             DefaultingPyMlirContext context) {
+            SmallVector<MlirAffineExpr> affineExprs;
+            affineExprs.reserve(py::len(exprs));
+            for (py::handle expr : exprs) {
+              try {
+                affineExprs.push_back(expr.cast<PyAffineExpr>());
+              } catch (py::cast_error &err) {
+                std::string msg =
+                    std::string("Invalid expression when attempting to create "
+                                "an AffineMap (") +
+                    err.what() + ")";
+                throw py::cast_error(msg);
+              } catch (py::reference_cast_error &err) {
+                std::string msg =
+                    std::string("Invalid expression (None?) when attempting to "
+                                "create an AffineMap (") +
+                    err.what() + ")";
+                throw py::cast_error(msg);
+              }
+            }
+            MlirAffineMap map =
+                mlirAffineMapGet(context->get(), dimCount, symbolCount,
+                                 affineExprs.size(), affineExprs.data());
+            return PyAffineMap(context->getRef(), map);
+          },
+          py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
+          py::arg("context") = py::none(),
+          "Gets a map with the given expressions as results.")
+      .def_static(
+          "get_constant",
+          [](intptr_t value, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapConstantGet(context->get(), value);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("value"), py::arg("context") = py::none(),
+          "Gets an affine map with a single constant result")
       .def_static(
           "get_empty",
           [](DefaultingPyMlirContext context) {
@@ -3748,14 +3849,106 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             return PyAffineMap(context->getRef(), affineMap);
           },
           py::arg("context") = py::none(), "Gets an empty affine map.")
+      .def_static(
+          "get_identity",
+          [](intptr_t nDims, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("context") = py::none(),
+          "Gets an identity map with the given number of dimensions.")
+      .def_static(
+          "get_minor_identity",
+          [](intptr_t nDims, intptr_t nResults,
+             DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("n_results"),
+          py::arg("context") = py::none(),
+          "Gets a minor identity map with the given number of dimensions and "
+          "results.")
+      .def_static(
+          "get_permutation",
+          [](std::vector<unsigned> permutation,
+             DefaultingPyMlirContext context) {
+            // Require a non-empty list holding each of 0..size-1 exactly once,
+            // so malformed input raises ValueError here instead of being
+            // handed to the C API unchecked.
+            if (permutation.empty())
+              throw py::value_error("invalid permutation");
+            std::vector<bool> seen(permutation.size(), false);
+            for (unsigned pos : permutation) {
+              if (pos >= permutation.size() || seen[pos])
+                throw py::value_error("invalid permutation");
+              seen[pos] = true;
+            }
+            MlirAffineMap affineMap = mlirAffineMapPermutationGet(
+                context->get(), permutation.size(), permutation.data());
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("permutation"), py::arg("context") = py::none(),
+          "Gets an affine map that permutes its inputs.")
+      .def(
+          "get_submap",
+          [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
+            intptr_t numResults = mlirAffineMapGetNumResults(self);
+            for (intptr_t pos : resultPos) {
+              if (pos < 0 || pos >= numResults)
+                throw py::value_error("result position out of bounds");
+            }
+            MlirAffineMap affineMap = mlirAffineMapGetSubMap(
+                self, resultPos.size(), resultPos.data());
+            return PyAffineMap(self.getContext(), affineMap);
+          },
+          py::arg("result_positions"),
+          "Gets a map containing the results of this map at the given positions.")
+      .def(
+          "get_major_submap",
+          [](PyAffineMap &self, intptr_t nResults) {
+            // Accept 1..numResults (inclusive: asking for every result is
+            // valid). Reject counts <= 0; NOTE(review): a count of 0 appears
+            // to produce a null AffineMap in the C++ API - confirm.
+            if (nResults <= 0 || nResults > mlirAffineMapGetNumResults(self))
+              throw py::value_error("number of results out of bounds");
+            MlirAffineMap affineMap =
+                mlirAffineMapGetMajorSubMap(self, nResults);
+            return PyAffineMap(self.getContext(), affineMap);
+          },
+          py::arg("n_results"),
+          "Gets a map with the first `n_results` results of this map.")
+      .def(
+          "get_minor_submap",
+          [](PyAffineMap &self, intptr_t nResults) {
+            // Same bounds as get_major_submap.
+            if (nResults <= 0 || nResults > mlirAffineMapGetNumResults(self))
+              throw py::value_error("number of results out of bounds");
+            MlirAffineMap affineMap =
+                mlirAffineMapGetMinorSubMap(self, nResults);
+            return PyAffineMap(self.getContext(), affineMap);
+          },
+          py::arg("n_results"),
+          "Gets a map with the last `n_results` results of this map.")
       .def_property_readonly(
-          "context",
-          [](PyAffineMap &self) { return self.getContext().getObject(); },
-          "Context that owns the Affine Map")
-      .def("__eq__",
-           [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
-      .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
-      .def(
-          "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
-          kDumpDocstring);
+          "is_permutation",
+          [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
+      .def_property_readonly("is_projected_permutation",
+                             [](PyAffineMap &self) {
+                               return mlirAffineMapIsProjectedPermutation(self);
+                             })
+      .def_property_readonly(
+          "n_dims",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
+      .def_property_readonly(
+          "n_inputs",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
+      .def_property_readonly(
+          "n_symbols",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
+      .def_property_readonly("results", [](PyAffineMap &self) {
+        return PyAffineMapExprList(self);
+      });
+  PyAffineMapExprList::bind(m);
 }

diff  --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
index ac81586144f5..f532d5dae72e 100644
--- a/mlir/lib/CAPI/IR/AffineMap.cpp
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir/CAPI/AffineExpr.h"
 #include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Utils.h"
@@ -37,11 +38,19 @@ MlirAffineMap mlirAffineMapEmptyGet(MlirContext ctx) {
   return wrap(AffineMap::get(unwrap(ctx)));
 }
 
-MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount,
-                               intptr_t symbolCount) {
+MlirAffineMap mlirAffineMapZeroResultGet(MlirContext ctx, intptr_t dimCount,
+                                         intptr_t symbolCount) {
   return wrap(AffineMap::get(dimCount, symbolCount, unwrap(ctx)));
 }
 
+MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount,
+                               intptr_t symbolCount, intptr_t nAffineExprs,
+                               MlirAffineExpr *affineExprs) {
+  SmallVector<AffineExpr, 4> exprs; // Scratch storage passed to unwrapList.
+  ArrayRef<AffineExpr> exprList = unwrapList(nAffineExprs, affineExprs, exprs);
+  return wrap(AffineMap::get(dimCount, symbolCount, exprList, unwrap(ctx)));
+}
+
 MlirAffineMap mlirAffineMapConstantGet(MlirContext ctx, int64_t val) {
   return wrap(AffineMap::getConstantMap(val, unwrap(ctx)));
 }
@@ -94,6 +103,10 @@ intptr_t mlirAffineMapGetNumResults(MlirAffineMap affineMap) {
   return unwrap(affineMap).getNumResults();
 }
 
+MlirAffineExpr mlirAffineMapGetResult(MlirAffineMap affineMap, intptr_t pos) {
+  return wrap(unwrap(affineMap).getResult(static_cast<unsigned>(pos)));
+}
+
 intptr_t mlirAffineMapGetNumInputs(MlirAffineMap affineMap) {
   return unwrap(affineMap).getNumInputs();
 }

diff  --git a/mlir/test/Bindings/Python/ir_affine_map.py b/mlir/test/Bindings/Python/ir_affine_map.py
index f66826659233..fe37eb971555 100644
--- a/mlir/test/Bindings/Python/ir_affine_map.py
+++ b/mlir/test/Bindings/Python/ir_affine_map.py
@@ -22,3 +22,151 @@ def testAffineMapCapsule():
   assert am2.context is ctx
 
 run(testAffineMapCapsule)
+
+
+# CHECK-LABEL: TEST: testAffineMapGet
+def testAffineMapGet():
+  with Context() as ctx:
+    d0 = AffineDimExpr.get(0)
+    d1 = AffineDimExpr.get(1)
+    c2 = AffineConstantExpr.get(2)
+
+    # CHECK: (d0, d1)[s0, s1, s2] -> ()
+    map0 = AffineMap.get(2, 3, [])  # No result expressions.
+    print(map0)
+
+    # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
+    map1 = AffineMap.get(2, 3, [d1, c2])
+    print(map1)
+
+    # CHECK: () -> (2)
+    map2 = AffineMap.get(0, 0, [c2])
+    print(map2)
+
+    # CHECK: (d0, d1) -> (d0, d1)
+    map3 = AffineMap.get(2, 0, [d0, d1])
+    print(map3)
+
+    # CHECK: (d0, d1) -> (d1)
+    map4 = AffineMap.get(2, 0, [d1])
+    print(map4)
+
+    # CHECK: (d0, d1, d2) -> (d2, d0, d1)
+    map5 = AffineMap.get_permutation([2, 0, 1])
+    print(map5)
+
+    assert map1 == AffineMap.get(2, 3, [d1, c2])
+    assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
+    assert map2 == AffineMap.get_constant(2)
+    assert map3 == AffineMap.get_identity(2)
+    assert map4 == AffineMap.get_minor_identity(2, 1)
+
+    try:
+      AffineMap.get(1, 1, [1])  # An int is not an AffineExpr.
+    except RuntimeError as e:
+      # CHECK: Invalid expression when attempting to create an AffineMap
+      print(e)
+
+    try:
+      AffineMap.get(1, 1, [None])  # None is not an AffineExpr either.
+    except RuntimeError as e:
+      # CHECK: Invalid expression (None?) when attempting to create an AffineMap
+      print(e)
+
+    try:
+      map3.get_submap([42])
+    except ValueError as e:
+      # CHECK: result position out of bounds
+      print(e)
+
+    try:
+      map3.get_minor_submap(42)
+    except ValueError as e:
+      # CHECK: number of results out of bounds
+      print(e)
+
+    try:
+      map3.get_major_submap(42)
+    except ValueError as e:
+      # CHECK: number of results out of bounds
+      print(e)
+
+run(testAffineMapGet)
+
+
+# CHECK-LABEL: TEST: testAffineMapDerive
+def testAffineMapDerive():
+  with Context() as ctx:
+    map5 = AffineMap.get_identity(5)
+
+    # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
+    map123 = map5.get_submap([1,2,3])
+    print(map123)
+
+    # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
+    map01 = map5.get_major_submap(2)
+    print(map01)
+
+    # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
+    map34 = map5.get_minor_submap(2)
+    print(map34)
+
+run(testAffineMapDerive)
+
+
+# CHECK-LABEL: TEST: testAffineMapProperties
+def testAffineMapProperties():
+  with Context():
+    d0 = AffineDimExpr.get(0)
+    d1 = AffineDimExpr.get(1)
+    d2 = AffineDimExpr.get(2)
+    map1 = AffineMap.get(3, 0, [d2, d0])
+    map2 = AffineMap.get(3, 0, [d2, d0, d1])
+    map3 = AffineMap.get(3, 1, [d2, d0, d1])  # map2 plus an unused symbol.
+    # CHECK: False
+    print(map1.is_permutation)
+    # CHECK: True
+    print(map1.is_projected_permutation)
+    # CHECK: True
+    print(map2.is_permutation)
+    # CHECK: True
+    print(map2.is_projected_permutation)
+    # CHECK: False
+    print(map3.is_permutation)
+    # CHECK: False
+    print(map3.is_projected_permutation)
+
+run(testAffineMapProperties)
+
+
+# CHECK-LABEL: TEST: testAffineMapExprs
+def testAffineMapExprs():
+  with Context():
+    d0 = AffineDimExpr.get(0)
+    d1 = AffineDimExpr.get(1)
+    d2 = AffineDimExpr.get(2)
+    map3 = AffineMap.get(3, 1, [d2, d0, d1])
+
+    # CHECK: 3
+    print(map3.n_dims)
+    # CHECK: 4
+    print(map3.n_inputs)
+    # CHECK: 1
+    print(map3.n_symbols)
+    assert map3.n_inputs == map3.n_dims + map3.n_symbols
+
+    # CHECK: 3
+    print(len(map3.results))
+    for expr in map3.results:
+      # CHECK: d2
+      # CHECK: d0
+      # CHECK: d1
+      print(expr)
+    for expr in map3.results[-1:-4:-1]:  # Reversed via a negative-step slice.
+      # CHECK: d1
+      # CHECK: d0
+      # CHECK: d2
+      print(expr)
+    assert list(map3.results) == [d2, d0, d1]
+
+run(testAffineMapExprs)

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 550f799440f9..b4775978b17c 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1007,7 +1007,7 @@ int printBuiltinAttributes(MlirContext ctx) {
 
 int printAffineMap(MlirContext ctx) {
   MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx);
-  MlirAffineMap affineMap = mlirAffineMapGet(ctx, 3, 2);
+  MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, 3, 2);
   MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2);
   MlirAffineMap multiDimIdentityAffineMap =
       mlirAffineMapMultiDimIdentityGet(ctx, 3);
@@ -1275,6 +1275,29 @@ int printAffineExpr(MlirContext ctx) {
   return 0;
 }
 
+int affineMapFromExprs(MlirContext ctx) { // Returns 0 on success.
+  MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0);
+  MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1);
+  MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr}; // d0, s1
+  MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, 2, exprs);
+
+  // CHECK-LABEL: @affineMapFromExprs
+  fprintf(stderr, "@affineMapFromExprs");
+  // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
+  mlirAffineMapDump(map);
+
+  if (mlirAffineMapGetNumResults(map) != 2)
+    return 1;
+
+  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr))
+    return 2;
+
+  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
+    return 3;
+
+  return 0;
+}
+
 int registerOnlyStd() {
   MlirContext ctx = mlirContextCreate();
   // The built-in dialect is always loaded.
@@ -1375,8 +1398,10 @@ int main() {
     return 4;
   if (printAffineExpr(ctx))
     return 5;
-  if (registerOnlyStd())
+  if (affineMapFromExprs(ctx))
     return 6;
+  if (registerOnlyStd())
+    return 7;
 
   mlirContextDestroy(ctx);
 


        


More information about the Mlir-commits mailing list