[Mlir-commits] [mlir] 47cc166 - [MLIR] [Python] Make Attribute and Type hashable

John Demme llvmlistbot at llvm.org
Wed Sep 22 20:00:20 PDT 2021


Author: John Demme
Date: 2021-09-22T19:59:03-07:00
New Revision: 47cc166bc023b497bdffe0964d80f15eaee8b7da

URL: https://github.com/llvm/llvm-project/commit/47cc166bc023b497bdffe0964d80f15eaee8b7da
DIFF: https://github.com/llvm/llvm-project/commit/47cc166bc023b497bdffe0964d80f15eaee8b7da.diff

LOG: [MLIR] [Python] Make Attribute and Type hashable

Enables putting types and attributes in sets and in dicts as keys.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110301

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/test/python/ir/attributes.py
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7763f4671d58c..473c94c900c01 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2364,6 +2364,7 @@ void mlir::python::populateIRCore(py::module &m) {
       .def("__eq__",
            [](PyAttribute &self, PyAttribute &other) { return self == other; })
       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
+      .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
       .def(
           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
           kDumpDocstring)
@@ -2457,6 +2458,7 @@ void mlir::python::populateIRCore(py::module &m) {
           "Context that owns the Type")
       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
       .def("__eq__", [](PyType &self, py::object &other) { return false; })
+      .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
       .def(
           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
       .def(

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index bdda2e3843acf..89da559669db8 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -57,6 +57,28 @@ def testAttrEq():
     print("a1 == None:", a1 == None)
 
 
+# CHECK-LABEL: TEST: testAttrHash
+@run
+def testAttrHash():
+  with Context():
+    a1 = Attribute.parse('"attr1"')
+    a2 = Attribute.parse('"attr2"')
+    a3 = Attribute.parse('"attr1"')
+    # CHECK: hash(a1) == hash(a3): True
+    print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
+    # In general, hashes don't have to be unique. In this case, however, the
+    # hash is just the underlying pointer so it will be.
+    # CHECK: hash(a1) == hash(a2): False
+    print("hash(a1) == hash(a2):", a1.__hash__() == a2.__hash__())
+
+    s = set()
+    s.add(a1)
+    s.add(a2)
+    s.add(a3)
+    # CHECK: len(s): 2
+    print("len(s): ", len(s))
+
+
 # CHECK-LABEL: TEST: testAttrCast
 @run
 def testAttrCast():
@@ -382,4 +404,3 @@ def testArrayAttr():
     except RuntimeError as e:
       # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
       print("Error: ", e)
-

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 053e2ef3423ed..ab6502e1d61fc 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -57,6 +57,28 @@ def testTypeEq():
   print("t1 == None:", t1 == None)
 
 
+# CHECK-LABEL: TEST: testTypeHash
+@run
+def testTypeHash():
+  ctx = Context()
+  t1 = Type.parse("i32", ctx)
+  t2 = Type.parse("f32", ctx)
+  t3 = Type.parse("i32", ctx)
+
+  # CHECK: hash(t1) == hash(t3): True
+  print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
+  # In general, hashes don't have to be unique. In this case, however, the
+  # hash is just the underlying pointer so it will be.
+  # CHECK: hash(t1) == hash(t2): False
+  print("hash(t1) == hash(t2):", t1.__hash__() == t2.__hash__())
+
+  s = set()
+  s.add(t1)
+  s.add(t2)
+  s.add(t3)
+  # CHECK: len(s): 2
+  print("len(s): ", len(s))
+
 # CHECK-LABEL: TEST: testTypeCast
 @run
 def testTypeCast():


        


More information about the Mlir-commits mailing list