[Mlir-commits] [mlir] a2c8aeb - [mlir][Python] Finish adding RankedTensorType support for encoding.
Stella Laurenzo
llvmlistbot at llvm.org
Mon May 10 13:43:18 PDT 2021
Author: Stella Laurenzo
Date: 2021-05-10T20:39:16Z
New Revision: a2c8aebd8f8f81ba0af1c50580036faf73e8e2dc
URL: https://github.com/llvm/llvm-project/commit/a2c8aebd8f8f81ba0af1c50580036faf73e8e2dc
DIFF: https://github.com/llvm/llvm-project/commit/a2c8aebd8f8f81ba0af1c50580036faf73e8e2dc.diff
LOG: [mlir][Python] Finish adding RankedTensorType support for encoding.
Differential Revision: https://reviews.llvm.org/D102184
Added:
Modified:
mlir/include/mlir-c/BuiltinTypes.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/test/CAPI/ir.c
mlir/test/python/dialects/sparse_tensor/dialect.py
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 7d45452af5f69..a677d4d365b11 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -203,6 +203,10 @@ MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(
MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType,
MlirAttribute encoding);
+/// Gets the 'encoding' attribute from the ranked tensor type, returning a null
+/// attribute if none.
+MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type);
+
/// Creates an unranked tensor type with the given element type in the same
/// context as the element type. The type is owned by the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index b6875c76e09c5..568cca160a595 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -338,10 +338,11 @@ class PyRankedTensorType
c.def_static(
"get",
[](std::vector<int64_t> shape, PyType &elementType,
+ llvm::Optional<PyAttribute> &encodingAttr,
DefaultingPyLocation loc) {
- MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirType t = mlirRankedTensorTypeGetChecked(
- loc, shape.size(), shape.data(), elementType, encodingAttr);
+ loc, shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@@ -355,8 +356,17 @@ class PyRankedTensorType
}
return PyRankedTensorType(elementType.getContext(), t);
},
- py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
+ py::arg("shape"), py::arg("element_type"),
+ py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
"Create a ranked tensor type");
+ c.def_property_readonly(
+ "encoding",
+ [](PyRankedTensorType &self) -> llvm::Optional<PyAttribute> {
+ MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+ if (mlirAttributeIsNull(encoding))
+ return llvm::None;
+ return PyAttribute(self.getContext(), encoding);
+ });
}
};
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 1e5fa8a32023b..d978f17b98d5b 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -206,6 +206,10 @@ MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
unwrap(elementType), unwrap(encoding)));
}
+MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
+ return wrap(unwrap(type).cast<RankedTensorType>().getEncoding());
+}
+
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
return wrap(UnrankedTensorType::get(unwrap(elementType)));
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index cb9aa5de523ec..7176cbb2625f8 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -690,7 +690,8 @@ static int printBuiltinTypes(MlirContext ctx) {
MlirType rankedTensor = mlirRankedTensorTypeGet(
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
if (!mlirTypeIsATensor(rankedTensor) ||
- !mlirTypeIsARankedTensor(rankedTensor))
+ !mlirTypeIsARankedTensor(rankedTensor) ||
+ !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
return 16;
mlirTypeDump(rankedTensor);
fprintf(stderr, "\n");
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 9b8c66c327458..581f5eab250cf 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -73,3 +73,18 @@ def testEncodingAttr2D():
print(created)
# CHECK: created_equal: True
print(f"created_equal: {created == casted}")
+
+
+# CHECK-LABEL: TEST: testEncodingAttrOnTensor
+@run
+def testEncodingAttrOnTensor():
+ with Context() as ctx, Location.unknown():
+ encoding = st.EncodingAttr(Attribute.parse(
+ '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
+ 'pointerBitWidth = 16, indexBitWidth = 32 }>'))
+ tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+ # CHECK: tensor<1024xf32, #sparse_tensor
+ print(tt)
+ # CHECK: #sparse_tensor.encoding
+ print(tt.encoding)
+ assert tt.encoding == encoding
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index a2cc2da894973..053e2ef3423ed 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -294,6 +294,9 @@ def testRankedTensorType():
else:
print("Exception not produced")
+ # Encoding should be None.
+ assert RankedTensorType.get(shape, f32).encoding is None
+
# CHECK-LABEL: TEST: testUnrankedTensorType
@run
More information about the Mlir-commits
mailing list