[Mlir-commits] [mlir] 34f72d9 - [mlir][python] expose the shape property of shaped types

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 3 02:49:17 PDT 2021


Author: Alex Zinenko
Date: 2021-11-03T10:49:12+01:00
New Revision: 34f72d91252b92e956b80b97e0d586e1ddce5221

URL: https://github.com/llvm/llvm-project/commit/34f72d91252b92e956b80b97e0d586e1ddce5221
DIFF: https://github.com/llvm/llvm-project/commit/34f72d91252b92e956b80b97e0d586e1ddce5221.diff

LOG: [mlir][python] expose the shape property of shaped types

This was missing from the original definition of shaped types.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D113025

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 1cfd799bf693..89fdb1f06a91 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -284,6 +284,19 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
         },
         "Returns whether the given value is used as a placeholder for dynamic "
         "strides and offsets in shaped types.");
+    c.def_property_readonly(
+        "shape",
+        [](PyShapedType &self) {
+          self.requireHasRank();
+
+          std::vector<int64_t> shape;
+          int64_t rank = mlirShapedTypeGetRank(self);
+          shape.reserve(rank);
+          for (int64_t i = 0; i < rank; ++i)
+            shape.push_back(mlirShapedTypeGetDimSize(self, i));
+          return shape;
+        },
+        "Returns the shape of the ranked shaped type as a list of integers.");
   }
 
 private:

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index c5b32e8ea018..7d881b90f0fb 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -315,6 +315,9 @@ def testRankedTensorType():
     # Encoding should be None.
     assert RankedTensorType.get(shape, f32).encoding is None
 
+    tensor = RankedTensorType.get(shape, f32)
+    assert tensor.shape == shape
+
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run
@@ -396,6 +399,8 @@ def testMemRefType():
     else:
       print("Exception not produced")
 
+    assert memref.shape == shape
+
 
 # CHECK-LABEL: TEST: testUnrankedMemRefType
 @run
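
The logic of the new `shape` property can be sketched in plain Python. This is only an illustrative model of the C++ lambda added in IRTypes.cpp: check that the type is ranked, then collect one dimension size per rank index. The class `ShapedTypeModel`, its methods, and the choice of `ValueError` are hypothetical stand-ins for this sketch and are not part of the MLIR Python bindings.

```python
from typing import List, Optional


class ShapedTypeModel:
    """Hypothetical stand-in for a shaped type; NOT the MLIR binding class."""

    def __init__(self, dims: Optional[List[int]]):
        # None models an unranked type; a list of sizes models a ranked one.
        self._dims = dims

    @property
    def has_rank(self) -> bool:
        return self._dims is not None

    def _require_has_rank(self) -> None:
        # Plays the role of requireHasRank() in the diff. The exception type
        # here is an assumption made for this sketch.
        if not self.has_rank:
            raise ValueError("this operation requires a ranked type")

    @property
    def rank(self) -> int:
        self._require_has_rank()
        return len(self._dims)

    def get_dim_size(self, i: int) -> int:
        # Plays the role of mlirShapedTypeGetDimSize(self, i).
        self._require_has_rank()
        return self._dims[i]

    @property
    def shape(self) -> List[int]:
        # Same steps as the lambda: rank check, then one size per dimension.
        self._require_has_rank()
        return [self.get_dim_size(i) for i in range(self.rank)]


ranked = ShapedTypeModel([3, 2])
unranked = ShapedTypeModel(None)
```

In this model, `ranked.shape` compares equal to the list the type was built from, which mirrors the `tensor.shape == shape` and `memref.shape == shape` assertions in the test diff, while `unranked.shape` raises instead of returning a list.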


        

