[Mlir-commits] [mlir] ef1b735 - [MLIR][python bindings] Add support for DenseElementsAttr of IndexType
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 3 16:56:53 PDT 2023
Author: max
Date: 2023-05-03T18:45:40-05:00
New Revision: ef1b735dfbed00a11a34bf4169a5b3fd8816b52f
URL: https://github.com/llvm/llvm-project/commit/ef1b735dfbed00a11a34bf4169a5b3fd8816b52f
DIFF: https://github.com/llvm/llvm-project/commit/ef1b735dfbed00a11a34bf4169a5b3fd8816b52f.diff
LOG: [MLIR][python bindings] Add support for DenseElementsAttr of IndexType
Differential Revision: https://reviews.llvm.org/D149690
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/ir/array_attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 5e7138b21c752..22001957ffa29 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -710,6 +710,10 @@ class PyDenseElementsAttribute
// f16
return bufferInfo<uint16_t>(shapedType, "e");
}
+ if (mlirTypeIsAIndex(elementType)) {
+ // Same as IndexType::kInternalStorageBitWidth
+ return bufferInfo<int64_t>(shapedType);
+ }
if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 32) {
if (mlirIntegerTypeIsSignless(elementType) ||
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 36b0769b20653..3de4edb884157 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -365,3 +365,20 @@ def testGetDenseElementsUI64():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
+
+# CHECK-LABEL: TEST: testGetDenseElementsIndex
+@run
+def testGetDenseElementsIndex():
+ with Context(), Location.unknown():
+ idx_type = IndexType.get()
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+ attr = DenseElementsAttr.get(array, type=idx_type)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
+ print(attr)
+ arr = np.array(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(arr)
+ # CHECK: True
+ print(arr.dtype == np.int64)
+
More information about the Mlir-commits
mailing list