[Mlir-commits] [mlir] 34f72d9 - [mlir][python] expose the shape property of shaped types
Alex Zinenko
llvmlistbot at llvm.org
Wed Nov 3 02:49:17 PDT 2021
Author: Alex Zinenko
Date: 2021-11-03T10:49:12+01:00
New Revision: 34f72d91252b92e956b80b97e0d586e1ddce5221
URL: https://github.com/llvm/llvm-project/commit/34f72d91252b92e956b80b97e0d586e1ddce5221
DIFF: https://github.com/llvm/llvm-project/commit/34f72d91252b92e956b80b97e0d586e1ddce5221.diff
LOG: [mlir][python] expose the shape property of shaped types
This property was missing from the original definition of shaped types.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D113025
Added:
Modified:
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 1cfd799bf693..89fdb1f06a91 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -284,6 +284,19 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
},
"Returns whether the given value is used as a placeholder for dynamic "
"strides and offsets in shaped types.");
+ c.def_property_readonly(
+ "shape",
+ [](PyShapedType &self) {
+ self.requireHasRank();
+
+ std::vector<int64_t> shape;
+ int64_t rank = mlirShapedTypeGetRank(self);
+ shape.reserve(rank);
+ for (int64_t i = 0; i < rank; ++i)
+ shape.push_back(mlirShapedTypeGetDimSize(self, i));
+ return shape;
+ },
+ "Returns the shape of the ranked shaped type as a list of integers.");
}
private:
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index c5b32e8ea018..7d881b90f0fb 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -315,6 +315,9 @@ def testRankedTensorType():
# Encoding should be None.
assert RankedTensorType.get(shape, f32).encoding is None
+ tensor = RankedTensorType.get(shape, f32)
+ assert tensor.shape == shape
+
# CHECK-LABEL: TEST: testUnrankedTensorType
@run
@@ -396,6 +399,8 @@ def testMemRefType():
else:
print("Exception not produced")
+ assert memref.shape == shape
+
# CHECK-LABEL: TEST: testUnrankedMemRefType
@run
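
For readers without an MLIR build at hand, the logic of the new "shape" getter can be sketched in plain Python. This is a minimal model, not the bindings themselves: ShapedTypeModel and all of its members are hypothetical names invented for illustration, and the use of ValueError for the unranked case is an assumption about how requireHasRank reports the error.

```python
from typing import List, Optional


class ShapedTypeModel:
    """Hypothetical stand-in for a shaped type; not part of the MLIR bindings."""

    def __init__(self, dims: Optional[List[int]]) -> None:
        # None models an unranked type; a list models a ranked one.
        self._dims = dims

    @property
    def has_rank(self) -> bool:
        return self._dims is not None

    def _require_has_rank(self) -> None:
        # Assumption: the real bindings signal this case with a ValueError.
        if not self.has_rank:
            raise ValueError("this property requires that the type has a rank")

    @property
    def rank(self) -> int:
        self._require_has_rank()
        return len(self._dims)

    def get_dim_size(self, index: int) -> int:
        self._require_has_rank()
        return self._dims[index]

    @property
    def shape(self) -> List[int]:
        # Same steps as the C++ lambda in the diff: check that the type is
        # ranked, then collect the size of each dimension into a new list.
        self._require_has_rank()
        return [self.get_dim_size(i) for i in range(self.rank)]


ranked = ShapedTypeModel([3, 2])
assert ranked.shape == [3, 2]

unranked = ShapedTypeModel(None)
try:
    unranked.shape
except ValueError as error:
    print("unranked type:", error)
```

As in the tests added by the commit, the property compares equal to the list the type was built from, and each access builds a fresh list rather than exposing internal storage.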