[Mlir-commits] [mlir] a2c8aeb - [mlir][Python] Finish adding RankedTensorType support for encoding.

Stella Laurenzo llvmlistbot at llvm.org
Mon May 10 13:43:18 PDT 2021


Author: Stella Laurenzo
Date: 2021-05-10T20:39:16Z
New Revision: a2c8aebd8f8f81ba0af1c50580036faf73e8e2dc

URL: https://github.com/llvm/llvm-project/commit/a2c8aebd8f8f81ba0af1c50580036faf73e8e2dc
DIFF: https://github.com/llvm/llvm-project/commit/a2c8aebd8f8f81ba0af1c50580036faf73e8e2dc.diff

LOG: [mlir][Python] Finish adding RankedTensorType support for encoding.

Differential Revision: https://reviews.llvm.org/D102184

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/dialects/sparse_tensor/dialect.py
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 7d45452af5f69..a677d4d365b11 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -203,6 +203,10 @@ MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(
     MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType,
     MlirAttribute encoding);
 
+/// Gets the 'encoding' attribute from the ranked tensor type, returning a null
+/// attribute if none.
+MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type);
+
 /// Creates an unranked tensor type with the given element type in the same
 /// context as the element type. The type is owned by the context.
 MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType);

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index b6875c76e09c5..568cca160a595 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -338,10 +338,11 @@ class PyRankedTensorType
     c.def_static(
         "get",
         [](std::vector<int64_t> shape, PyType &elementType,
+           llvm::Optional<PyAttribute> &encodingAttr,
            DefaultingPyLocation loc) {
-          MlirAttribute encodingAttr = mlirAttributeGetNull();
           MlirType t = mlirRankedTensorTypeGetChecked(
-              loc, shape.size(), shape.data(), elementType, encodingAttr);
+              loc, shape.size(), shape.data(), elementType,
+              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -355,8 +356,17 @@ class PyRankedTensorType
           }
           return PyRankedTensorType(elementType.getContext(), t);
         },
-        py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
+        py::arg("shape"), py::arg("element_type"),
+        py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
         "Create a ranked tensor type");
+    c.def_property_readonly(
+        "encoding",
+        [](PyRankedTensorType &self) -> llvm::Optional<PyAttribute> {
+          MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+          if (mlirAttributeIsNull(encoding))
+            return llvm::None;
+          return PyAttribute(self.getContext(), encoding);
+        });
   }
 };
 

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 1e5fa8a32023b..d978f17b98d5b 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -206,6 +206,10 @@ MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
       unwrap(elementType), unwrap(encoding)));
 }
 
+MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
+  return wrap(unwrap(type).cast<RankedTensorType>().getEncoding());
+}
+
 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
   return wrap(UnrankedTensorType::get(unwrap(elementType)));
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index cb9aa5de523ec..7176cbb2625f8 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -690,7 +690,8 @@ static int printBuiltinTypes(MlirContext ctx) {
   MlirType rankedTensor = mlirRankedTensorTypeGet(
       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
-      !mlirTypeIsARankedTensor(rankedTensor))
+      !mlirTypeIsARankedTensor(rankedTensor) ||
+      !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
     return 16;
   mlirTypeDump(rankedTensor);
   fprintf(stderr, "\n");

diff  --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 9b8c66c327458..581f5eab250cf 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -73,3 +73,18 @@ def testEncodingAttr2D():
     print(created)
     # CHECK: created_equal: True
     print(f"created_equal: {created == casted}")
+
+
+# CHECK-LABEL: TEST: testEncodingAttrOnTensor
+@run
+def testEncodingAttrOnTensor():
+  with Context() as ctx, Location.unknown():
+    encoding = st.EncodingAttr(Attribute.parse(
+      '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
+      'pointerBitWidth = 16, indexBitWidth = 32 }>'))
+    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+    # CHECK: tensor<1024xf32, #sparse_tensor
+    print(tt)
+    # CHECK: #sparse_tensor.encoding
+    print(tt.encoding)
+    assert tt.encoding == encoding

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index a2cc2da894973..053e2ef3423ed 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -294,6 +294,9 @@ def testRankedTensorType():
     else:
       print("Exception not produced")
 
+    # Encoding should be None.
+    assert RankedTensorType.get(shape, f32).encoding is None
+
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run


        


More information about the Mlir-commits mailing list