[Mlir-commits] [mlir] 974c159 - [mlir][python] Downcast attributes in more places
Rahul Kayaith
llvmlistbot at llvm.org
Mon Jul 10 19:01:40 PDT 2023
Author: Rahul Kayaith
Date: 2023-07-10T22:01:34-04:00
New Revision: 974c1596abdea379ec468bf14c3681fad7a53987
URL: https://github.com/llvm/llvm-project/commit/974c1596abdea379ec468bf14c3681fad7a53987
DIFF: https://github.com/llvm/llvm-project/commit/974c1596abdea379ec468bf14c3681fad7a53987.diff
LOG: [mlir][python] Downcast attributes in more places
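Update remaining `PyAttribute`-returning APIs to return `MlirAttribute` instead,
so that they go through the downcasting mechanism. As part of this change, the
`memory_space` properties of `MemRefType` and `UnrankedMemRefType` now return
`None` when no memory space is set, and `UnrankedMemRefType.memory_space` now
queries `mlirUnrankedMemrefGetMemorySpace` rather than the ranked-memref accessor.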
Reviewed By: makslevental
Differential Revision: https://reviews.llvm.org/D154462
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/test/python/ir/array_attributes.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/ir/operation.py
mlir/test/python/ir/symbol_table.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 4ee06fa7a6d751..84a48a890eb409 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -270,12 +270,11 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
PyArrayAttributeIterator &dunderIter() { return *this; }
- PyAttribute dunderNext() {
+ MlirAttribute dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw py::stop_iteration();
- return PyAttribute(attr.getContext(),
- mlirArrayAttrGetElement(attr.get(), nextIndex++));
+ return mlirArrayAttrGetElement(attr.get(), nextIndex++);
}
static void bind(py::module &m) {
@@ -290,8 +289,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
int nextIndex = 0;
};
- PyAttribute getItem(intptr_t i) {
- return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
+ MlirAttribute getItem(intptr_t i) {
+ return mlirArrayAttrGetElement(*this, i);
}
static void bindDerived(ClassTy &c) {
@@ -843,13 +842,11 @@ class PyDenseElementsAttribute
return mlirDenseElementsAttrIsSplat(self);
})
.def("get_splat_value",
- [](PyDenseElementsAttribute &self) -> PyAttribute {
- if (!mlirDenseElementsAttrIsSplat(self)) {
+ [](PyDenseElementsAttribute &self) {
+ if (!mlirDenseElementsAttrIsSplat(self))
throw py::value_error(
"get_splat_value called on a non-splat attribute");
- }
- return PyAttribute(self.getContext(),
- mlirDenseElementsAttrGetSplatValue(self));
+ return mlirDenseElementsAttrGetSplatValue(self);
})
.def_buffer(&PyDenseElementsAttribute::accessBuffer);
}
@@ -1018,10 +1015,9 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
MlirAttribute attr =
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
+ if (mlirAttributeIsNull(attr))
throw py::key_error("attempt to access a non-existent attribute");
- }
- return PyAttribute(self.getContext(), attr);
+ return attr;
});
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3ab6d57b41690d..6c0b4a0604e31d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1908,19 +1908,17 @@ void PySymbolTable::dunderDel(const std::string &name) {
erase(py::cast<PyOperationBase &>(operation));
}
-PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
+MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
operation->checkValid();
symbol.getOperation().checkValid();
MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
if (mlirAttributeIsNull(symbolAttr))
throw py::value_error("Expected operation to have a symbol name.");
- return PyAttribute(
- symbol.getOperation().getContext(),
- mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
+ return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
}
-PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
+MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
// Op must already be a symbol.
PyOperation &operation = symbol.getOperation();
operation.checkValid();
@@ -1929,7 +1927,7 @@ PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
mlirOperationGetAttributeByName(operation.get(), attrName);
if (mlirAttributeIsNull(existingNameAttr))
throw py::value_error("Expected operation to have a symbol name.");
- return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
+ return existingNameAttr;
}
void PySymbolTable::setSymbolName(PyOperationBase &symbol,
@@ -1947,7 +1945,7 @@ void PySymbolTable::setSymbolName(PyOperationBase &symbol,
mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
}
-PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
+MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
PyOperation &operation = symbol.getOperation();
operation.checkValid();
MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
@@ -1955,7 +1953,7 @@ PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
mlirOperationGetAttributeByName(operation.get(), attrName);
if (mlirAttributeIsNull(existingVisAttr))
throw py::value_error("Expected operation to have a symbol visibility.");
- return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
+ return existingVisAttr;
}
void PySymbolTable::setVisibility(PyOperationBase &symbol,
@@ -2287,13 +2285,13 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
- PyAttribute dunderGetItemNamed(const std::string &name) {
+ MlirAttribute dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
throw py::key_error("attempt to access a non-existent attribute");
}
- return PyAttribute(operation->getContext(), attr);
+ return attr;
}
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 225580f0f45759..76acfe5e7790be 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -1174,14 +1174,14 @@ class PySymbolTable {
/// Inserts the given operation into the symbol table. The operation must have
/// the symbol trait.
- PyAttribute insert(PyOperationBase &symbol);
+ MlirAttribute insert(PyOperationBase &symbol);
/// Gets and sets the name of a symbol op.
- static PyAttribute getSymbolName(PyOperationBase &symbol);
+ static MlirAttribute getSymbolName(PyOperationBase &symbol);
static void setSymbolName(PyOperationBase &symbol, const std::string &name);
/// Gets and sets the visibility of a symbol op.
- static PyAttribute getVisibility(PyOperationBase &symbol);
+ static MlirAttribute getVisibility(PyOperationBase &symbol);
static void setVisibility(PyOperationBase &symbol,
const std::string &visibility);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index caf215be85baa8..a7ccfbea542f5c 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -505,11 +505,12 @@ class PyRankedTensorType
py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
"Create a ranked tensor type");
c.def_property_readonly(
- "encoding", [](PyRankedTensorType &self) -> std::optional<PyAttribute> {
+ "encoding",
+ [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return std::nullopt;
- return PyAttribute(self.getContext(), encoding);
+ return encoding;
});
}
};
@@ -570,9 +571,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
py::arg("loc") = py::none(), "Create a memref type")
.def_property_readonly(
"layout",
- [](PyMemRefType &self) -> PyAttribute {
- MlirAttribute layout = mlirMemRefTypeGetLayout(self);
- return PyAttribute(self.getContext(), layout);
+ [](PyMemRefType &self) -> MlirAttribute {
+ return mlirMemRefTypeGetLayout(self);
},
"The layout of the MemRef type.")
.def_property_readonly(
@@ -584,9 +584,11 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
"The layout of the MemRef type as an affine map.")
.def_property_readonly(
"memory_space",
- [](PyMemRefType &self) -> PyAttribute {
+ [](PyMemRefType &self) -> std::optional<MlirAttribute> {
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
- return PyAttribute(self.getContext(), a);
+ if (mlirAttributeIsNull(a))
+ return std::nullopt;
+ return a;
},
"Returns the memory space of the given MemRef type.");
}
@@ -622,9 +624,11 @@ class PyUnrankedMemRefType
py::arg("loc") = py::none(), "Create a unranked memref type")
.def_property_readonly(
"memory_space",
- [](PyUnrankedMemRefType &self) -> PyAttribute {
- MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
- return PyAttribute(self.getContext(), a);
+ [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
+ MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+ if (mlirAttributeIsNull(a))
+ return std::nullopt;
+ return a;
},
"Returns the memory space of the given Unranked MemRef type.");
}
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 5ce8bc66fcf96e..b592804013b545 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -47,7 +47,11 @@ def testGetDenseElementsSplatInt():
print(attr)
# CHECK: is_splat: True
print("is_splat:", attr.is_splat)
- assert attr.get_splat_value() == element
+
+ # CHECK: splat_value: IntegerAttr(555 : i32)
+ splat_value = attr.get_splat_value()
+ print("splat_value:", repr(splat_value))
+ assert splat_value == element
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 28729e86ccd4c0..d986cac17dd765 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -441,11 +441,11 @@ def testDictAttr():
assert len(a) == 2
- # CHECK: 42 : i32
- print(a["integerattr"])
+ # CHECK: integerattr: IntegerAttr(42 : i32)
+ print("integerattr:", repr(a["integerattr"]))
- # CHECK: "string"
- print(a["stringattr"])
+ # CHECK: stringattr: StringAttr("string")
+ print("stringattr:", repr(a["stringattr"]))
# CHECK: True
print("stringattr" in a)
@@ -488,14 +488,14 @@ def testTypeAttr():
@run
def testArrayAttr():
with Context():
- raw = Attribute.parse("[42, true, vector<4xf32>]")
- # CHECK: attr: [42, true, vector<4xf32>]
- print("raw attr:", raw)
- # CHECK: - 42
- # CHECK: - true
- # CHECK: - vector<4xf32>
- for attr in ArrayAttr(raw):
- print("- ", attr)
+ arr = Attribute.parse("[42, true, vector<4xf32>]")
+ # CHECK: arr: [42, true, vector<4xf32>]
+ print("arr:", arr)
+ # CHECK: - IntegerAttr(42 : i64)
+ # CHECK: - BoolAttr(true)
+ # CHECK: - TypeAttr(vector<4xf32>)
+ for attr in arr:
+ print("- ", repr(attr))
with Context():
intAttr = Attribute.parse("42")
@@ -504,18 +504,18 @@ def testArrayAttr():
raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
# CHECK: attr: [vector<4xf32>, true, 42]
print("raw attr:", raw)
- # CHECK: - vector<4xf32>
- # CHECK: - true
- # CHECK: - 42
- arr = ArrayAttr(raw)
+ # CHECK: - TypeAttr(vector<4xf32>)
+ # CHECK: - BoolAttr(true
+ # CHECK: - IntegerAttr(42 : i64)
+ arr = raw
for attr in arr:
- print("- ", attr)
- # CHECK: attr[0]: vector<4xf32>
- print("attr[0]:", arr[0])
- # CHECK: attr[1]: true
- print("attr[1]:", arr[1])
- # CHECK: attr[2]: 42
- print("attr[2]:", arr[2])
+ print("- ", repr(attr))
+ # CHECK: attr[0]: TypeAttr(vector<4xf32>)
+ print("attr[0]:", repr(arr[0]))
+ # CHECK: attr[1]: BoolAttr(true)
+ print("attr[1]:", repr(arr[1]))
+ # CHECK: attr[2]: IntegerAttr(42 : i64)
+ print("attr[2]:", repr(arr[2]))
try:
print("attr[3]:", arr[3])
except IndexError as e:
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 51a311dec94419..672418b5383ae4 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -329,11 +329,13 @@ def testRankedTensorType():
else:
print("Exception not produced")
+ tensor = RankedTensorType.get(shape, f32, StringAttr.get("encoding"))
+ assert tensor.shape == shape
+ assert tensor.encoding.value == "encoding"
+
# Encoding should be None.
assert RankedTensorType.get(shape, f32).encoding is None
- tensor = RankedTensorType.get(shape, f32)
- assert tensor.shape == shape
# CHECK-LABEL: TEST: testUnrankedTensorType
@@ -388,12 +390,12 @@ def testMemRefType():
memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref_f32)
- # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
- print("memref layout:", memref_f32.layout)
+ # CHECK: memref layout: AffineMapAttr(affine_map<(d0, d1) -> (d0, d1)>)
+ print("memref layout:", repr(memref_f32.layout))
# CHECK: memref affine map: (d0, d1) -> (d0, d1)
print("memref affine map:", memref_f32.affine_map)
- # CHECK: memory space: 2
- print("memory space:", memref_f32.memory_space)
+ # CHECK: memory space: IntegerAttr(2 : i64)
+ print("memory space:", repr(memref_f32.memory_space))
layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
memref_layout = MemRefType.get(shape, f32, layout=layout)
@@ -403,7 +405,7 @@ def testMemRefType():
print("memref layout:", memref_layout.layout)
# CHECK: memref affine map: (d0, d1) -> (d1, d0)
print("memref affine map:", memref_layout.affine_map)
- # CHECK: memory space: <<NULL ATTRIBUTE>>
+ # CHECK: memory space: None
print("memory space:", memref_layout.memory_space)
none = NoneType.get()
@@ -428,6 +430,8 @@ def testUnrankedMemRefType():
unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
# CHECK: unranked memref type: memref<*xf32, 2>
print("unranked memref type:", unranked_memref)
+ # CHECK: memory space: IntegerAttr(2 : i64)
+ print("memory space:", repr(unranked_memref.memory_space))
try:
invalid_rank = unranked_memref.rank
except ValueError as e:
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 639f8ff2b42551..9679b5846af190 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -532,9 +532,9 @@ def testOperationAttributes():
)
op = module.body.operations[0]
assert len(op.attributes) == 3
- iattr = IntegerAttr(op.attributes["some.attribute"])
- fattr = FloatAttr(op.attributes["other.attribute"])
- sattr = StringAttr(op.attributes["dependent"])
+ iattr = op.attributes["some.attribute"]
+ fattr = op.attributes["other.attribute"]
+ sattr = op.attributes["dependent"]
# CHECK: Attribute type i8, value 1
print(f"Attribute type {iattr.type}, value {iattr.value}")
# CHECK: Attribute type f64, value 3.0
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 17f3e354bee2be..3264cfcf9a1049 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -75,6 +75,7 @@ def testSymbolTableInsert():
updated_name = symbol_table.insert(foo2)
assert foo2.name.value != "foo"
assert foo2.name == updated_name
+ assert isinstance(updated_name, StringAttr)
# CHECK: module
# CHECK: func private @foo()
@@ -112,10 +113,10 @@ def testSymbolTableRAUW():
# CHECK: call @bam()
# CHECK: func private @bam
print(m)
- # CHECK: Foo symbol: "foo"
- # CHECK: Bar symbol: "bam"
- print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
- print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
+ # CHECK: Foo symbol: StringAttr("foo")
+ # CHECK: Bar symbol: StringAttr("bam")
+ print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
+ print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
# CHECK-LABEL: testSymbolTableVisibility
@@ -130,8 +131,8 @@ def testSymbolTableVisibility():
"""
)
foo = m.operation.regions[0].blocks[0].operations[0]
- # CHECK: Existing visibility: "private"
- print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
+ # CHECK: Existing visibility: StringAttr("private")
+ print(f"Existing visibility: {repr(SymbolTable.get_visibility(foo))}")
SymbolTable.set_visibility(foo, "public")
# CHECK: func public @foo
print(m)
More information about the Mlir-commits
mailing list