[Mlir-commits] [mlir] f78fe0b - [mlir][python] Make Operation and Value hashable
Alex Zinenko
llvmlistbot at llvm.org
Wed Nov 3 02:40:08 PDT 2021
Author: rkayaith
Date: 2021-11-03T10:40:03+01:00
New Revision: f78fe0b7b83842929262f1f05e308efe82fc6ffa
URL: https://github.com/llvm/llvm-project/commit/f78fe0b7b83842929262f1f05e308efe82fc6ffa
DIFF: https://github.com/llvm/llvm-project/commit/f78fe0b7b83842929262f1f05e308efe82fc6ffa.diff
LOG: [mlir][python] Make Operation and Value hashable
This allows operations and values to be used as dict keys.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D112669
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/test/python/ir/attributes.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/ir/operation.py
mlir/test/python/ir/value.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8f451cf34bed3..d465c1382459c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2171,6 +2171,10 @@ void mlir::python::populateIRCore(py::module &m) {
})
.def("__eq__",
[](PyOperationBase &self, py::object other) { return false; })
+ .def("__hash__",
+ [](PyOperationBase &self) {
+ return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
+ })
.def_property_readonly("attributes",
[](PyOperationBase &self) {
return PyOpAttributeMap(
@@ -2558,7 +2562,10 @@ void mlir::python::populateIRCore(py::module &m) {
.def("__eq__",
[](PyAttribute &self, PyAttribute &other) { return self == other; })
.def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
- .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
+ .def("__hash__",
+ [](PyAttribute &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ })
.def(
"dump", [](PyAttribute &self) { mlirAttributeDump(self); },
kDumpDocstring)
@@ -2652,7 +2659,10 @@ void mlir::python::populateIRCore(py::module &m) {
"Context that owns the Type")
.def("__eq__", [](PyType &self, PyType &other) { return self == other; })
.def("__eq__", [](PyType &self, py::object &other) { return false; })
- .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
+ .def("__hash__",
+ [](PyType &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ })
.def(
"dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
.def(
@@ -2703,6 +2713,10 @@ void mlir::python::populateIRCore(py::module &m) {
return self.get().ptr == other.get().ptr;
})
.def("__eq__", [](PyValue &self, py::object other) { return false; })
+ .def("__hash__",
+ [](PyValue &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ })
.def(
"__str__",
[](PyValue &self) {
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 9bbf23cf20855..5f8dd0ad1183f 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -66,10 +66,6 @@ def testAttrHash():
a3 = Attribute.parse('"attr1"')
# CHECK: hash(a1) == hash(a3): True
print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
- # In general, hashes don't have to be unique. In this case, however, the
- # hash is just the underlying pointer so it will be.
- # CHECK: hash(a1) == hash(a2): False
- print("hash(a1) == hash(a2):", a1.__hash__() == a2.__hash__())
s = set()
s.add(a1)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 911391f2d528b..c5b32e8ea0183 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -67,10 +67,6 @@ def testTypeHash():
# CHECK: hash(t1) == hash(t3): True
print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
- # In general, hashes don't have to be unique. In this case, however, the
- # hash is just the underlying pointer so it will be.
- # CHECK: hash(t1) == hash(t2): False
- print("hash(t1) == hash(t2):", t1.__hash__() == t2.__hash__())
s = set()
s.add(t1)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 950c31217fa22..8771ca046b8b8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -741,6 +741,7 @@ def testOperationLoc():
assert op.location == loc
assert op.operation.location == loc
+
# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
@@ -876,3 +877,13 @@ def testSymbolTable():
raise
else:
assert False, "exepcted ValueError when adding a non-symbol"
+
+
+# CHECK-LABEL: TEST: testOperationHash
+@run
+def testOperationHash():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with ctx, Location.unknown():
+ op = Operation.create("custom.op1")
+ assert hash(op) == hash(op.operation)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 230025a8f3306..cfbfdf35fcf6f 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -55,3 +55,22 @@ def testValueIsInstance():
op = func.regions[0].blocks[0].operations[0]
assert not BlockArgument.isinstance(op.results[0])
assert OpResult.isinstance(op.results[0])
+
+
+# CHECK-LABEL: TEST: testValueHash
+@run
+def testValueHash():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
+ func @foo(%arg0: f32) -> f32 {
+ %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
+ return %0 : f32
+ }""", ctx)
+
+ [func] = module.body.operations
+ block = func.entry_block
+ op, ret = block.operations
+ assert hash(block.arguments[0]) == hash(op.operands[0])
+ assert hash(op.result) == hash(ret.operands[0])
More information about the Mlir-commits
mailing list