[Mlir-commits] [mlir] 974c159 - [mlir][python] Downcast attributes in more places

Rahul Kayaith llvmlistbot at llvm.org
Mon Jul 10 19:01:40 PDT 2023


Author: Rahul Kayaith
Date: 2023-07-10T22:01:34-04:00
New Revision: 974c1596abdea379ec468bf14c3681fad7a53987

URL: https://github.com/llvm/llvm-project/commit/974c1596abdea379ec468bf14c3681fad7a53987
DIFF: https://github.com/llvm/llvm-project/commit/974c1596abdea379ec468bf14c3681fad7a53987.diff

LOG: [mlir][python] Downcast attributes in more places

Update remaining `PyAttribute`-returning APIs to return `MlirAttribute` instead,
so that they go through the downcasting mechanism.

Reviewed By: makslevental

Differential Revision: https://reviews.llvm.org/D154462

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/test/python/ir/array_attributes.py
    mlir/test/python/ir/attributes.py
    mlir/test/python/ir/builtin_types.py
    mlir/test/python/ir/operation.py
    mlir/test/python/ir/symbol_table.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 4ee06fa7a6d751..84a48a890eb409 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -270,12 +270,11 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 
     PyArrayAttributeIterator &dunderIter() { return *this; }
 
-    PyAttribute dunderNext() {
+    MlirAttribute dunderNext() {
       // TODO: Throw is an inefficient way to stop iteration.
       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
         throw py::stop_iteration();
-      return PyAttribute(attr.getContext(),
-                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
+      return mlirArrayAttrGetElement(attr.get(), nextIndex++);
     }
 
     static void bind(py::module &m) {
@@ -290,8 +289,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
     int nextIndex = 0;
   };
 
-  PyAttribute getItem(intptr_t i) {
-    return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
+  MlirAttribute getItem(intptr_t i) {
+    return mlirArrayAttrGetElement(*this, i);
   }
 
   static void bindDerived(ClassTy &c) {
@@ -843,13 +842,11 @@ class PyDenseElementsAttribute
                                  return mlirDenseElementsAttrIsSplat(self);
                                })
         .def("get_splat_value",
-             [](PyDenseElementsAttribute &self) -> PyAttribute {
-               if (!mlirDenseElementsAttrIsSplat(self)) {
+             [](PyDenseElementsAttribute &self) {
+               if (!mlirDenseElementsAttrIsSplat(self))
                  throw py::value_error(
                      "get_splat_value called on a non-splat attribute");
-               }
-               return PyAttribute(self.getContext(),
-                                  mlirDenseElementsAttrGetSplatValue(self));
+               return mlirDenseElementsAttrGetSplatValue(self);
              })
         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }
@@ -1018,10 +1015,9 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
       MlirAttribute attr =
           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
-      if (mlirAttributeIsNull(attr)) {
+      if (mlirAttributeIsNull(attr))
         throw py::key_error("attempt to access a non-existent attribute");
-      }
-      return PyAttribute(self.getContext(), attr);
+      return attr;
     });
     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
       if (index < 0 || index >= self.dunderLen()) {

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3ab6d57b41690d..6c0b4a0604e31d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1908,19 +1908,17 @@ void PySymbolTable::dunderDel(const std::string &name) {
   erase(py::cast<PyOperationBase &>(operation));
 }
 
-PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
+MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
   operation->checkValid();
   symbol.getOperation().checkValid();
   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
   if (mlirAttributeIsNull(symbolAttr))
     throw py::value_error("Expected operation to have a symbol name.");
-  return PyAttribute(
-      symbol.getOperation().getContext(),
-      mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
+  return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
 }
 
-PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
+MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
   // Op must already be a symbol.
   PyOperation &operation = symbol.getOperation();
   operation.checkValid();
@@ -1929,7 +1927,7 @@ PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
       mlirOperationGetAttributeByName(operation.get(), attrName);
   if (mlirAttributeIsNull(existingNameAttr))
     throw py::value_error("Expected operation to have a symbol name.");
-  return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
+  return existingNameAttr;
 }
 
 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
@@ -1947,7 +1945,7 @@ void PySymbolTable::setSymbolName(PyOperationBase &symbol,
   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
 }
 
-PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
+MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
   PyOperation &operation = symbol.getOperation();
   operation.checkValid();
   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
@@ -1955,7 +1953,7 @@ PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
       mlirOperationGetAttributeByName(operation.get(), attrName);
   if (mlirAttributeIsNull(existingVisAttr))
     throw py::value_error("Expected operation to have a symbol visibility.");
-  return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
+  return existingVisAttr;
 }
 
 void PySymbolTable::setVisibility(PyOperationBase &symbol,
@@ -2287,13 +2285,13 @@ class PyOpAttributeMap {
   PyOpAttributeMap(PyOperationRef operation)
       : operation(std::move(operation)) {}
 
-  PyAttribute dunderGetItemNamed(const std::string &name) {
+  MlirAttribute dunderGetItemNamed(const std::string &name) {
     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
                                                          toMlirStringRef(name));
     if (mlirAttributeIsNull(attr)) {
       throw py::key_error("attempt to access a non-existent attribute");
     }
-    return PyAttribute(operation->getContext(), attr);
+    return attr;
   }
 
   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 225580f0f45759..76acfe5e7790be 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -1174,14 +1174,14 @@ class PySymbolTable {
 
   /// Inserts the given operation into the symbol table. The operation must have
   /// the symbol trait.
-  PyAttribute insert(PyOperationBase &symbol);
+  MlirAttribute insert(PyOperationBase &symbol);
 
   /// Gets and sets the name of a symbol op.
-  static PyAttribute getSymbolName(PyOperationBase &symbol);
+  static MlirAttribute getSymbolName(PyOperationBase &symbol);
   static void setSymbolName(PyOperationBase &symbol, const std::string &name);
 
   /// Gets and sets the visibility of a symbol op.
-  static PyAttribute getVisibility(PyOperationBase &symbol);
+  static MlirAttribute getVisibility(PyOperationBase &symbol);
   static void setVisibility(PyOperationBase &symbol,
                             const std::string &visibility);
 

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index caf215be85baa8..a7ccfbea542f5c 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -505,11 +505,12 @@ class PyRankedTensorType
         py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
         "Create a ranked tensor type");
     c.def_property_readonly(
-        "encoding", [](PyRankedTensorType &self) -> std::optional<PyAttribute> {
+        "encoding",
+        [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
           MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
           if (mlirAttributeIsNull(encoding))
             return std::nullopt;
-          return PyAttribute(self.getContext(), encoding);
+          return encoding;
         });
   }
 };
@@ -570,9 +571,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
          py::arg("loc") = py::none(), "Create a memref type")
         .def_property_readonly(
             "layout",
-            [](PyMemRefType &self) -> PyAttribute {
-              MlirAttribute layout = mlirMemRefTypeGetLayout(self);
-              return PyAttribute(self.getContext(), layout);
+            [](PyMemRefType &self) -> MlirAttribute {
+              return mlirMemRefTypeGetLayout(self);
             },
             "The layout of the MemRef type.")
         .def_property_readonly(
@@ -584,9 +584,11 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
             "The layout of the MemRef type as an affine map.")
         .def_property_readonly(
             "memory_space",
-            [](PyMemRefType &self) -> PyAttribute {
+            [](PyMemRefType &self) -> std::optional<MlirAttribute> {
               MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
-              return PyAttribute(self.getContext(), a);
+              if (mlirAttributeIsNull(a))
+                return std::nullopt;
+              return a;
             },
             "Returns the memory space of the given MemRef type.");
   }
@@ -622,9 +624,11 @@ class PyUnrankedMemRefType
          py::arg("loc") = py::none(), "Create a unranked memref type")
         .def_property_readonly(
             "memory_space",
-            [](PyUnrankedMemRefType &self) -> PyAttribute {
-              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
-              return PyAttribute(self.getContext(), a);
+            [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
+              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+              if (mlirAttributeIsNull(a))
+                return std::nullopt;
+              return a;
             },
             "Returns the memory space of the given Unranked MemRef type.");
   }

diff  --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 5ce8bc66fcf96e..b592804013b545 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -47,7 +47,11 @@ def testGetDenseElementsSplatInt():
         print(attr)
         # CHECK: is_splat: True
         print("is_splat:", attr.is_splat)
-        assert attr.get_splat_value() == element
+
+        # CHECK: splat_value: IntegerAttr(555 : i32)
+        splat_value = attr.get_splat_value()
+        print("splat_value:", repr(splat_value))
+        assert splat_value == element
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 28729e86ccd4c0..d986cac17dd765 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -441,11 +441,11 @@ def testDictAttr():
 
         assert len(a) == 2
 
-        # CHECK: 42 : i32
-        print(a["integerattr"])
+        # CHECK: integerattr: IntegerAttr(42 : i32)
+        print("integerattr:", repr(a["integerattr"]))
 
-        # CHECK: "string"
-        print(a["stringattr"])
+        # CHECK: stringattr: StringAttr("string")
+        print("stringattr:", repr(a["stringattr"]))
 
         # CHECK: True
         print("stringattr" in a)
@@ -488,14 +488,14 @@ def testTypeAttr():
 @run
 def testArrayAttr():
     with Context():
-        raw = Attribute.parse("[42, true, vector<4xf32>]")
-    # CHECK: attr: [42, true, vector<4xf32>]
-    print("raw attr:", raw)
-    # CHECK: - 42
-    # CHECK: - true
-    # CHECK: - vector<4xf32>
-    for attr in ArrayAttr(raw):
-        print("- ", attr)
+        arr = Attribute.parse("[42, true, vector<4xf32>]")
+    # CHECK: arr: [42, true, vector<4xf32>]
+    print("arr:", arr)
+    # CHECK: - IntegerAttr(42 : i64)
+    # CHECK: - BoolAttr(true)
+    # CHECK: - TypeAttr(vector<4xf32>)
+    for attr in arr:
+        print("- ", repr(attr))
 
     with Context():
         intAttr = Attribute.parse("42")
@@ -504,18 +504,18 @@ def testArrayAttr():
         raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
     # CHECK: attr: [vector<4xf32>, true, 42]
     print("raw attr:", raw)
-    # CHECK: - vector<4xf32>
-    # CHECK: - true
-    # CHECK: - 42
-    arr = ArrayAttr(raw)
+    # CHECK: - TypeAttr(vector<4xf32>)
+    # CHECK: - BoolAttr(true
+    # CHECK: - IntegerAttr(42 : i64)
+    arr = raw
     for attr in arr:
-        print("- ", attr)
-    # CHECK: attr[0]: vector<4xf32>
-    print("attr[0]:", arr[0])
-    # CHECK: attr[1]: true
-    print("attr[1]:", arr[1])
-    # CHECK: attr[2]: 42
-    print("attr[2]:", arr[2])
+        print("- ", repr(attr))
+    # CHECK: attr[0]: TypeAttr(vector<4xf32>)
+    print("attr[0]:", repr(arr[0]))
+    # CHECK: attr[1]: BoolAttr(true)
+    print("attr[1]:", repr(arr[1]))
+    # CHECK: attr[2]: IntegerAttr(42 : i64)
+    print("attr[2]:", repr(arr[2]))
     try:
         print("attr[3]:", arr[3])
     except IndexError as e:

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 51a311dec94419..672418b5383ae4 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -329,11 +329,13 @@ def testRankedTensorType():
         else:
             print("Exception not produced")
 
+        tensor = RankedTensorType.get(shape, f32, StringAttr.get("encoding"))
+        assert tensor.shape == shape
+        assert tensor.encoding.value == "encoding"
+
         # Encoding should be None.
         assert RankedTensorType.get(shape, f32).encoding is None
 
-        tensor = RankedTensorType.get(shape, f32)
-        assert tensor.shape == shape
 
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
@@ -388,12 +390,12 @@ def testMemRefType():
         memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
         # CHECK: memref type: memref<2x3xf32, 2>
         print("memref type:", memref_f32)
-        # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
-        print("memref layout:", memref_f32.layout)
+        # CHECK: memref layout: AffineMapAttr(affine_map<(d0, d1) -> (d0, d1)>)
+        print("memref layout:", repr(memref_f32.layout))
         # CHECK: memref affine map: (d0, d1) -> (d0, d1)
         print("memref affine map:", memref_f32.affine_map)
-        # CHECK: memory space: 2
-        print("memory space:", memref_f32.memory_space)
+        # CHECK: memory space: IntegerAttr(2 : i64)
+        print("memory space:", repr(memref_f32.memory_space))
 
         layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
         memref_layout = MemRefType.get(shape, f32, layout=layout)
@@ -403,7 +405,7 @@ def testMemRefType():
         print("memref layout:", memref_layout.layout)
         # CHECK: memref affine map: (d0, d1) -> (d1, d0)
         print("memref affine map:", memref_layout.affine_map)
-        # CHECK: memory space: <<NULL ATTRIBUTE>>
+        # CHECK: memory space: None
         print("memory space:", memref_layout.memory_space)
 
         none = NoneType.get()
@@ -428,6 +430,8 @@ def testUnrankedMemRefType():
         unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
         # CHECK: unranked memref type: memref<*xf32, 2>
         print("unranked memref type:", unranked_memref)
+        # CHECK: memory space: IntegerAttr(2 : i64)
+        print("memory space:", repr(unranked_memref.memory_space))
         try:
             invalid_rank = unranked_memref.rank
         except ValueError as e:

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 639f8ff2b42551..9679b5846af190 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -532,9 +532,9 @@ def testOperationAttributes():
     )
     op = module.body.operations[0]
     assert len(op.attributes) == 3
-    iattr = IntegerAttr(op.attributes["some.attribute"])
-    fattr = FloatAttr(op.attributes["other.attribute"])
-    sattr = StringAttr(op.attributes["dependent"])
+    iattr = op.attributes["some.attribute"]
+    fattr = op.attributes["other.attribute"]
+    sattr = op.attributes["dependent"]
     # CHECK: Attribute type i8, value 1
     print(f"Attribute type {iattr.type}, value {iattr.value}")
     # CHECK: Attribute type f64, value 3.0

diff  --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 17f3e354bee2be..3264cfcf9a1049 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -75,6 +75,7 @@ def testSymbolTableInsert():
         updated_name = symbol_table.insert(foo2)
         assert foo2.name.value != "foo"
         assert foo2.name == updated_name
+        assert isinstance(updated_name, StringAttr)
 
         # CHECK: module
         # CHECK:   func private @foo()
@@ -112,10 +113,10 @@ def testSymbolTableRAUW():
         # CHECK: call @bam()
         # CHECK: func private @bam
         print(m)
-        # CHECK: Foo symbol: "foo"
-        # CHECK: Bar symbol: "bam"
-        print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
-        print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
+        # CHECK: Foo symbol: StringAttr("foo")
+        # CHECK: Bar symbol: StringAttr("bam")
+        print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
+        print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
 
 
 # CHECK-LABEL: testSymbolTableVisibility
@@ -130,8 +131,8 @@ def testSymbolTableVisibility():
       """
         )
         foo = m.operation.regions[0].blocks[0].operations[0]
-        # CHECK: Existing visibility: "private"
-        print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
+        # CHECK: Existing visibility: StringAttr("private")
+        print(f"Existing visibility: {repr(SymbolTable.get_visibility(foo))}")
         SymbolTable.set_visibility(foo, "public")
         # CHECK: func public @foo
         print(m)


        


More information about the Mlir-commits mailing list