[Mlir-commits] [mlir] f78fe0b - [mlir][python] Make Operation and Value hashable

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 3 02:40:08 PDT 2021


Author: rkayaith
Date: 2021-11-03T10:40:03+01:00
New Revision: f78fe0b7b83842929262f1f05e308efe82fc6ffa

URL: https://github.com/llvm/llvm-project/commit/f78fe0b7b83842929262f1f05e308efe82fc6ffa
DIFF: https://github.com/llvm/llvm-project/commit/f78fe0b7b83842929262f1f05e308efe82fc6ffa.diff

LOG: [mlir][python] Make Operation and Value hashable

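This allows operations and values to be used as dict keys. Attribute and Type `__hash__` now also go through `llvm::hash_value` instead of returning the raw pointer, so the tests no longer assert that distinct attributes/types have distinct hashes.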

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D112669

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/test/python/ir/attributes.py
    mlir/test/python/ir/builtin_types.py
    mlir/test/python/ir/operation.py
    mlir/test/python/ir/value.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8f451cf34bed3..d465c1382459c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2171,6 +2171,10 @@ void mlir::python::populateIRCore(py::module &m) {
            })
       .def("__eq__",
            [](PyOperationBase &self, py::object other) { return false; })
+      .def("__hash__",
+           [](PyOperationBase &self) {
+             return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
+           })
       .def_property_readonly("attributes",
                              [](PyOperationBase &self) {
                                return PyOpAttributeMap(
@@ -2558,7 +2562,10 @@ void mlir::python::populateIRCore(py::module &m) {
       .def("__eq__",
            [](PyAttribute &self, PyAttribute &other) { return self == other; })
       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
-      .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
+      .def("__hash__",
+           [](PyAttribute &self) {
+             return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+           })
       .def(
           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
           kDumpDocstring)
@@ -2652,7 +2659,10 @@ void mlir::python::populateIRCore(py::module &m) {
           "Context that owns the Type")
       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
       .def("__eq__", [](PyType &self, py::object &other) { return false; })
-      .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
+      .def("__hash__",
+           [](PyType &self) {
+             return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+           })
       .def(
           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
       .def(
@@ -2703,6 +2713,10 @@ void mlir::python::populateIRCore(py::module &m) {
              return self.get().ptr == other.get().ptr;
            })
       .def("__eq__", [](PyValue &self, py::object other) { return false; })
+      .def("__hash__",
+           [](PyValue &self) {
+             return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+           })
       .def(
           "__str__",
           [](PyValue &self) {

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 9bbf23cf20855..5f8dd0ad1183f 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -66,10 +66,6 @@ def testAttrHash():
     a3 = Attribute.parse('"attr1"')
     # CHECK: hash(a1) == hash(a3): True
     print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
-    # In general, hashes don't have to be unique. In this case, however, the
-    # hash is just the underlying pointer so it will be.
-    # CHECK: hash(a1) == hash(a2): False
-    print("hash(a1) == hash(a2):", a1.__hash__() == a2.__hash__())
 
     s = set()
     s.add(a1)

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 911391f2d528b..c5b32e8ea0183 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -67,10 +67,6 @@ def testTypeHash():
 
   # CHECK: hash(t1) == hash(t3): True
   print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
-  # In general, hashes don't have to be unique. In this case, however, the
-  # hash is just the underlying pointer so it will be.
-  # CHECK: hash(t1) == hash(t2): False
-  print("hash(t1) == hash(t2):", t1.__hash__() == t2.__hash__())
 
   s = set()
   s.add(t1)

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 950c31217fa22..8771ca046b8b8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -741,6 +741,7 @@ def testOperationLoc():
     assert op.location == loc
     assert op.operation.location == loc
 
+
 # CHECK-LABEL: TEST: testModuleMerge
 @run
 def testModuleMerge():
@@ -876,3 +877,13 @@ def testSymbolTable():
         raise
     else:
       assert False, "exepcted ValueError when adding a non-symbol"
+
+
+# CHECK-LABEL: TEST: testOperationHash
+@run
+def testOperationHash():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with ctx, Location.unknown():
+    op = Operation.create("custom.op1")
+    assert hash(op) == hash(op.operation)

diff  --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 230025a8f3306..cfbfdf35fcf6f 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -55,3 +55,22 @@ def testValueIsInstance():
   op = func.regions[0].blocks[0].operations[0]
   assert not BlockArgument.isinstance(op.results[0])
   assert OpResult.isinstance(op.results[0])
+
+
+# CHECK-LABEL: TEST: testValueHash
+@run
+def testValueHash():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  module = Module.parse(
+      r"""
+    func @foo(%arg0: f32) -> f32 {
+      %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
+      return %0 : f32
+    }""", ctx)
+
+  [func] = module.body.operations
+  block = func.entry_block
+  op, ret = block.operations
+  assert hash(block.arguments[0]) == hash(op.operands[0])
+  assert hash(op.result) == hash(ret.operands[0])


        


More information about the Mlir-commits mailing list