[Mlir-commits] [mlir] fc7594c - [mlir][python] improve usability of Python affine construct bindings

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 3 02:48:07 PDT 2021


Author: Alex Zinenko
Date: 2021-11-03T10:48:01+01:00
New Revision: fc7594cc4aa5e652fe61f278a13e865141797245

URL: https://github.com/llvm/llvm-project/commit/fc7594cc4aa5e652fe61f278a13e865141797245
DIFF: https://github.com/llvm/llvm-project/commit/fc7594cc4aa5e652fe61f278a13e865141797245.diff

LOG: [mlir][python] improve usability of Python affine construct bindings

- Provide the operator overloads for constructing (semi-)affine expressions in
  Python by combining existing expressions with constants.
- Make AffineExpr, AffineMap and IntegerSet hashable in Python.
- Expose the AffineExpr composition functionality.

Reviewed By: gysit, aoyal

Differential Revision: https://reviews.llvm.org/D113010

Added: 
    

Modified: 
    mlir/include/mlir-c/AffineExpr.h
    mlir/lib/Bindings/Python/IRAffine.cpp
    mlir/lib/CAPI/IR/AffineExpr.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/affine_expr.py
    mlir/test/python/ir/affine_map.py
    mlir/test/python/ir/integer_set.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index 5516f29088e43..14e951ddee9ad 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -39,6 +39,8 @@ DEFINE_C_API_STRUCT(MlirAffineExpr, const void);
 
 #undef DEFINE_C_API_STRUCT
 
+struct MlirAffineMap;
+
 /// Gets the context that owns the affine expression.
 MLIR_CAPI_EXPORTED MlirContext
 mlirAffineExprGetContext(MlirAffineExpr affineExpr);
@@ -86,6 +88,10 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsMultipleOf(MlirAffineExpr affineExpr,
 MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
                                                       intptr_t position);
 
+/// Composes the given map with the given expression.
+MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose(
+    MlirAffineExpr affineExpr, struct MlirAffineMap affineMap);
+
 //===----------------------------------------------------------------------===//
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 50a96c8c8cede..da80cda9c5823 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -205,6 +205,18 @@ class PyAffineAddExpr
     return PyAffineAddExpr(lhs.getContext(), expr);
   }
 
+  static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
+    MlirAffineExpr expr = mlirAffineAddExprGet(
+        lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
+    return PyAffineAddExpr(lhs.getContext(), expr);
+  }
+
+  static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineAddExprGet(
+        mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
+    return PyAffineAddExpr(rhs.getContext(), expr);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyAffineAddExpr::get);
   }
@@ -222,6 +234,18 @@ class PyAffineMulExpr
     return PyAffineMulExpr(lhs.getContext(), expr);
   }
 
+  static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
+    MlirAffineExpr expr = mlirAffineMulExprGet(
+        lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
+    return PyAffineMulExpr(lhs.getContext(), expr);
+  }
+
+  static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineMulExprGet(
+        mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
+    return PyAffineMulExpr(rhs.getContext(), expr);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyAffineMulExpr::get);
   }
@@ -239,6 +263,18 @@ class PyAffineModExpr
     return PyAffineModExpr(lhs.getContext(), expr);
   }
 
+  static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
+    MlirAffineExpr expr = mlirAffineModExprGet(
+        lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
+    return PyAffineModExpr(lhs.getContext(), expr);
+  }
+
+  static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineModExprGet(
+        mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
+    return PyAffineModExpr(rhs.getContext(), expr);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyAffineModExpr::get);
   }
@@ -256,6 +292,18 @@ class PyAffineFloorDivExpr
     return PyAffineFloorDivExpr(lhs.getContext(), expr);
   }
 
+  static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
+    MlirAffineExpr expr = mlirAffineFloorDivExprGet(
+        lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
+    return PyAffineFloorDivExpr(lhs.getContext(), expr);
+  }
+
+  static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineFloorDivExprGet(
+        mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
+    return PyAffineFloorDivExpr(rhs.getContext(), expr);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyAffineFloorDivExpr::get);
   }
@@ -273,6 +321,18 @@ class PyAffineCeilDivExpr
     return PyAffineCeilDivExpr(lhs.getContext(), expr);
   }
 
+  static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
+    MlirAffineExpr expr = mlirAffineCeilDivExprGet(
+        lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
+    return PyAffineCeilDivExpr(lhs.getContext(), expr);
+  }
+
+  static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineCeilDivExprGet(
+        mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
+    return PyAffineCeilDivExpr(rhs.getContext(), expr);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyAffineCeilDivExpr::get);
   }
@@ -435,17 +495,19 @@ void mlir::python::populateIRAffine(py::module &m) {
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyAffineExpr::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
-      .def("__add__",
-           [](PyAffineExpr &self, PyAffineExpr &other) {
-             return PyAffineAddExpr::get(self, other);
-           })
-      .def("__mul__",
-           [](PyAffineExpr &self, PyAffineExpr &other) {
-             return PyAffineMulExpr::get(self, other);
-           })
-      .def("__mod__",
-           [](PyAffineExpr &self, PyAffineExpr &other) {
-             return PyAffineModExpr::get(self, other);
+      .def("__add__", &PyAffineAddExpr::get)
+      .def("__add__", &PyAffineAddExpr::getRHSConstant)
+      .def("__radd__", &PyAffineAddExpr::getRHSConstant)
+      .def("__mul__", &PyAffineMulExpr::get)
+      .def("__mul__", &PyAffineMulExpr::getRHSConstant)
+      .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
+      .def("__mod__", &PyAffineModExpr::get)
+      .def("__mod__", &PyAffineModExpr::getRHSConstant)
+      .def("__rmod__",
+           [](PyAffineExpr &self, intptr_t other) {
+             return PyAffineModExpr::get(
+                 PyAffineConstantExpr::get(other, *self.getContext().get()),
+                 self);
            })
       .def("__sub__",
            [](PyAffineExpr &self, PyAffineExpr &other) {
@@ -454,6 +516,17 @@ void mlir::python::populateIRAffine(py::module &m) {
              return PyAffineAddExpr::get(self,
                                          PyAffineMulExpr::get(negOne, other));
            })
+      .def("__sub__",
+           [](PyAffineExpr &self, intptr_t other) {
+             return PyAffineAddExpr::get(
+                 self,
+                 PyAffineConstantExpr::get(-other, *self.getContext().get()));
+           })
+      .def("__rsub__",
+           [](PyAffineExpr &self, intptr_t other) {
+             return PyAffineAddExpr::getLHSConstant(
+                 other, PyAffineMulExpr::getLHSConstant(-1, self));
+           })
       .def("__eq__", [](PyAffineExpr &self,
                         PyAffineExpr &other) { return self == other; })
       .def("__eq__",
@@ -474,24 +547,63 @@ void mlir::python::populateIRAffine(py::module &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
+      .def("__hash__",
+           [](PyAffineExpr &self) {
+             return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+           })
       .def_property_readonly(
           "context",
           [](PyAffineExpr &self) { return self.getContext().getObject(); })
+      .def("compose",
+           [](PyAffineExpr &self, PyAffineMap &other) {
+             return PyAffineExpr(self.getContext(),
+                                 mlirAffineExprCompose(self, other));
+           })
       .def_static(
           "get_add", &PyAffineAddExpr::get,
           "Gets an affine expression containing a sum of two expressions.")
+      .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
+                  "Gets an affine expression containing a sum of a constant "
+                  "and another expression.")
+      .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
+                  "Gets an affine expression containing a sum of an expression "
+                  "and a constant.")
       .def_static(
           "get_mul", &PyAffineMulExpr::get,
           "Gets an affine expression containing a product of two expressions.")
+      .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
+                  "Gets an affine expression containing a product of a "
+                  "constant and another expression.")
+      .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
+                  "Gets an affine expression containing a product of an "
+                  "expression and a constant.")
       .def_static("get_mod", &PyAffineModExpr::get,
                   "Gets an affine expression containing the modulo of dividing "
                   "one expression by another.")
+      .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
+                  "Gets a semi-affine expression containing the modulo of "
+                  "dividing a constant by an expression.")
+      .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
+                  "Gets an affine expression containing the modulo of dividing "
+                  "an expression by a constant.")
       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
                   "Gets an affine expression containing the rounded-down "
                   "result of dividing one expression by another.")
+      .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
+                  "Gets a semi-affine expression containing the rounded-down "
+                  "result of dividing a constant by an expression.")
+      .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
+                  "Gets an affine expression containing the rounded-down "
+                  "result of dividing an expression by a constant.")
       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
                   "Gets an affine expression containing the rounded-up result "
                   "of dividing one expression by another.")
+      .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
+                  "Gets a semi-affine expression containing the rounded-up "
+                  "result of dividing a constant by an expression.")
+      .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
+                  "Gets an affine expression containing the rounded-up result "
+                  "of dividing an expression by a constant.")
       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
                   py::arg("context") = py::none(),
                   "Gets a constant affine expression with the given value.")
@@ -542,6 +654,10 @@ void mlir::python::populateIRAffine(py::module &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
+      .def("__hash__",
+           [](PyAffineMap &self) {
+             return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+           })
       .def_static("compress_unused_symbols",
                   [](py::list affineMaps, DefaultingPyMlirContext context) {
                     SmallVector<MlirAffineMap> maps;
@@ -714,6 +830,10 @@ void mlir::python::populateIRAffine(py::module &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
+      .def("__hash__",
+           [](PyIntegerSet &self) {
+             return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+           })
       .def_property_readonly(
           "context",
           [](PyIntegerSet &self) { return self.getContext().getObject(); })

diff  --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index 2d8bc3ce569af..5b25ab5337e2f 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -56,6 +56,11 @@ bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
   return unwrap(affineExpr).isFunctionOfDim(position);
 }
 
+MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
+                                     MlirAffineMap affineMap) {
+  return wrap(unwrap(affineExpr).compose(unwrap(affineMap)));
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index ef555377c1ce5..1056f65080be7 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1393,6 +1393,13 @@ int affineMapFromExprs(MlirContext ctx) {
   if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
     return 3;
 
+  MlirAffineExpr affineDim2Expr = mlirAffineDimExprGet(ctx, 1);
+  MlirAffineExpr composed = mlirAffineExprCompose(affineDim2Expr, map);
+  // CHECK: s1
+  mlirAffineExprDump(composed);
+  if (!mlirAffineExprEqual(composed, affineSymbolExpr))
+    return 4;
+
   return 0;
 }
 

diff  --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index 184466870a578..9854b496fe460 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -137,6 +137,14 @@ def testAffineAddExpr():
     # CHECK: d1 + d2
     print(d12op)
 
+    d1cst_op = d1 + 2
+    # CHECK: d1 + 2
+    print(d1cst_op)
+
+    d1cst_op2 = 2 + d1
+    # CHECK: d1 + 2
+    print(d1cst_op2)
+
     assert d12 == d12op
     assert d12.lhs == d1
     assert d12.rhs == d2
@@ -156,7 +164,16 @@ def testAffineMulExpr():
     op = d1 * c2
     print(op)
 
+    # CHECK: d1 * 2
+    op_cst = d1 * 2
+    print(op_cst)
+
+    # CHECK: d1 * 2
+    op_cst2 = 2 * d1
+    print(op_cst2)
+
     assert expr == op
+    assert expr == op_cst
     assert expr.lhs == d1
     assert expr.rhs == c2
 
@@ -175,10 +192,32 @@ def testAffineModExpr():
     op = d1 % c2
     print(op)
 
+    # CHECK: d1 mod 2
+    op_cst = d1 % 2
+    print(op_cst)
+
+    # CHECK: 2 mod d1
+    print(2 % d1)
+
     assert expr == op
+    assert expr == op_cst
     assert expr.lhs == d1
     assert expr.rhs == c2
 
+    expr2 = AffineExpr.get_mod(c2, d1)
+    expr3 = AffineExpr.get_mod(2, d1)
+    expr4 = AffineExpr.get_mod(d1, 2)
+
+    # CHECK: 2 mod d1
+    print(expr2)
+    # CHECK: 2 mod d1
+    print(expr3)
+    # CHECK: d1 mod 2
+    print(expr4)
+
+    assert expr2 == expr3
+    assert expr4 == expr
+
 
 # CHECK-LABEL: TEST: testAffineFloorDivExpr
 @run
@@ -193,6 +232,20 @@ def testAffineFloorDivExpr():
     assert expr.lhs == d1
     assert expr.rhs == c2
 
+    expr2 = AffineExpr.get_floor_div(c2, d1)
+    expr3 = AffineExpr.get_floor_div(2, d1)
+    expr4 = AffineExpr.get_floor_div(d1, 2)
+
+    # CHECK: 2 floordiv d1
+    print(expr2)
+    # CHECK: 2 floordiv d1
+    print(expr3)
+    # CHECK: d1 floordiv 2
+    print(expr4)
+
+    assert expr2 == expr3
+    assert expr4 == expr
+
 
 # CHECK-LABEL: TEST: testAffineCeilDivExpr
 @run
@@ -207,6 +260,20 @@ def testAffineCeilDivExpr():
     assert expr.lhs == d1
     assert expr.rhs == c2
 
+    expr2 = AffineExpr.get_ceil_div(c2, d1)
+    expr3 = AffineExpr.get_ceil_div(2, d1)
+    expr4 = AffineExpr.get_ceil_div(d1, 2)
+
+    # CHECK: 2 ceildiv d1
+    print(expr2)
+    # CHECK: 2 ceildiv d1
+    print(expr3)
+    # CHECK: d1 ceildiv 2
+    print(expr4)
+
+    assert expr2 == expr3
+    assert expr4 == expr
+
 
 # CHECK-LABEL: TEST: testAffineExprSub
 @run
@@ -225,6 +292,15 @@ def testAffineExprSub():
     # CHECK: -1
     print(rhs.rhs)
 
+    # CHECK: d1 - 42
+    print(d1 - 42)
+    # CHECK: -d1 + 42
+    print(42 - d1)
+
+    c42 = AffineConstantExpr.get(42)
+    assert d1 - 42 == d1 - c42
+    assert 42 - d1 == c42 - d1
+
 # CHECK-LABEL: TEST: testClassHierarchy
 @run
 def testClassHierarchy():
@@ -289,3 +365,38 @@ def testIsInstance():
     print(AffineMulExpr.isinstance(mul))
     # CHECK: False
     print(AffineAddExpr.isinstance(mul))
+
+
+# CHECK-LABEL: TEST: testCompose
+@run
+def testCompose():
+  with Context():
+    # d0 + d2.
+    expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2))
+
+    # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
+    map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
+    map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
+    map3 = AffineAddExpr.get(
+        AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)),
+        AffineDimExpr.get(2))
+    map = AffineMap.get(3, 2, [map1, map2, map3])
+
+    # CHECK: d0 + s1 + d0 + d1 + d2
+    print(expr.compose(map))
+
+
+# CHECK-LABEL: TEST: testHash
+@run
+def testHash():
+  with Context():
+    d0 = AffineDimExpr.get(0)
+    s1 = AffineSymbolExpr.get(1)
+    assert hash(d0) == hash(AffineDimExpr.get(0))
+    assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1))
+
+    dictionary = dict()
+    dictionary[d0] = 0
+    dictionary[s1] = 1
+    assert d0 in dictionary
+    assert s1 in dictionary

diff  --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py
index da5d230f42cde..52c7261500c90 100644
--- a/mlir/test/python/ir/affine_map.py
+++ b/mlir/test/python/ir/affine_map.py
@@ -9,9 +9,11 @@ def run(f):
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # CHECK-LABEL: TEST: testAffineMapCapsule
+@run
 def testAffineMapCapsule():
   with Context() as ctx:
     am1 = AffineMap.get_empty(ctx)
@@ -23,10 +25,8 @@ def testAffineMapCapsule():
   assert am2.context is ctx
 
 
-run(testAffineMapCapsule)
-
-
 # CHECK-LABEL: TEST: testAffineMapGet
+@run
 def testAffineMapGet():
   with Context() as ctx:
     d0 = AffineDimExpr.get(0)
@@ -100,10 +100,8 @@ def testAffineMapGet():
       print(e)
 
 
-run(testAffineMapGet)
-
-
 # CHECK-LABEL: TEST: testAffineMapDerive
+@run
 def testAffineMapDerive():
   with Context() as ctx:
     map5 = AffineMap.get_identity(5)
@@ -121,10 +119,8 @@ def testAffineMapDerive():
     print(map34)
 
 
-run(testAffineMapDerive)
-
-
 # CHECK-LABEL: TEST: testAffineMapProperties
+@run
 def testAffineMapProperties():
   with Context():
     d0 = AffineDimExpr.get(0)
@@ -147,10 +143,8 @@ def testAffineMapProperties():
     print(map3.is_projected_permutation)
 
 
-run(testAffineMapProperties)
-
-
 # CHECK-LABEL: TEST: testAffineMapExprs
+@run
 def testAffineMapExprs():
   with Context():
     d0 = AffineDimExpr.get(0)
@@ -181,10 +175,8 @@ def testAffineMapExprs():
     assert list(map3.results) == [d2, d0, d1]
 
 
-run(testAffineMapExprs)
-
-
 # CHECK-LABEL: TEST: testCompressUnusedSymbols
+@run
 def testCompressUnusedSymbols():
   with Context() as ctx:
     d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
@@ -210,10 +202,8 @@ def testCompressUnusedSymbols():
     print(compressed_maps)
 
 
-run(testCompressUnusedSymbols)
-
-
 # CHECK-LABEL: TEST: testReplace
+@run
 def testReplace():
   with Context() as ctx:
     d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
@@ -236,4 +226,16 @@ def testReplace():
     print(replace3)
 
 
-run(testReplace)
+# CHECK-LABEL: TEST: testHash
+@run
+def testHash():
+  with Context():
+    d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
+    m1 = AffineMap.get(2, 0, [d0, d1])
+    m2 = AffineMap.get(2, 0, [d1, d0])
+    assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))
+
+    dictionary = dict()
+    dictionary[m1] = 1
+    dictionary[m2] = 2
+    assert m1 in dictionary

diff  --git a/mlir/test/python/ir/integer_set.py b/mlir/test/python/ir/integer_set.py
index bdec8afba0ebf..b916d9ab386e9 100644
--- a/mlir/test/python/ir/integer_set.py
+++ b/mlir/test/python/ir/integer_set.py
@@ -8,9 +8,11 @@ def run(f):
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # CHECK-LABEL: TEST: testIntegerSetCapsule
+@run
 def testIntegerSetCapsule():
   with Context() as ctx:
     is1 = IntegerSet.get_empty(1, 1, ctx)
@@ -21,10 +23,9 @@ def testIntegerSetCapsule():
   assert is1 == is2
   assert is2.context is ctx
 
-run(testIntegerSetCapsule)
-
 
 # CHECK-LABEL: TEST: testIntegerSetGet
+@run
 def testIntegerSetGet():
   with Context():
     d0 = AffineDimExpr.get(0)
@@ -92,10 +93,9 @@ def testIntegerSetGet():
       # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
       print(e)
 
-run(testIntegerSetGet)
-
 
 # CHECK-LABEL: TEST: testIntegerSetProperties
+@run
 def testIntegerSetProperties():
   with Context():
     d0 = AffineDimExpr.get(0)
@@ -125,4 +125,17 @@ def testIntegerSetProperties():
       print(cstr.expr, end='')
       print(" == 0" if cstr.is_eq else " >= 0")
 
-run(testIntegerSetProperties)
+
+# CHECK-LABEL: TEST: testHash
+@run
+def testHash():
+  with Context():
+    d0 = AffineDimExpr.get(0)
+    d1 = AffineDimExpr.get(1)
+    set = IntegerSet.get(2, 0, [d0 + d1], [True])
+
+    assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True]))
+
+    dictionary = dict()
+    dictionary[set] = 42
+    assert set in dictionary


        


More information about the Mlir-commits mailing list