[Mlir-commits] [mlir] 99dee31 - Make it possible to create DenseElementsAttrs with arbitrary shaped types in Python bindings
Jacques Pienaar
llvmlistbot at llvm.org
Wed Mar 8 11:12:28 PST 2023
Author: Adam Paszke
Date: 2023-03-08T11:11:45-08:00
New Revision: 99dee31ef48012f8984ddab806b7345c24b02a72
URL: https://github.com/llvm/llvm-project/commit/99dee31ef48012f8984ddab806b7345c24b02a72
DIFF: https://github.com/llvm/llvm-project/commit/99dee31ef48012f8984ddab806b7345c24b02a72.diff
LOG: Make it possible to create DenseElementsAttrs with arbitrary shaped types in Python bindings
Right now the bindings assume that all DenseElementsAttrs correspond to tensor values,
making it impossible to create vector-typed constants. I didn't want to change the API
significantly, so I opted to reuse the current signature of `.get`. Its `type` argument
now accepts either an element type (in which case `shape` and `signless` can also be specified)
or a shaped type, which specifies the full type of the created attr (`shape` cannot be specified
in that case).
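A plain-Python model of the new `type` handling may make the rule easier to follow. It is a minimal sketch, not the bindings' code: `ElementType`, `ShapedType` and `resolve_attr_type` are hypothetical stand-ins for `MlirType` values and the C++ branch in `IRAttributes.cpp`, and the sketch assumes that, without `shape`, the tensor shape is taken from the input array. The error mirrors the `std::invalid_argument` thrown in C++, which pybind11 reports to Python as `ValueError`.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union


@dataclass(frozen=True)
class ElementType:
    """Hypothetical stand-in for a scalar MLIR type such as i32."""
    name: str


@dataclass(frozen=True)
class ShapedType:
    """Hypothetical stand-in for a shaped MLIR type (tensor or vector)."""
    kind: str  # "tensor" or "vector"
    shape: Tuple[int, ...]
    element: ElementType

    def __str__(self) -> str:
        # Render in MLIR's textual form, e.g. vector<2x2xi32>.
        dims = "x".join(str(d) for d in self.shape)
        return f"{self.kind}<{dims}x{self.element.name}>"


def resolve_attr_type(
    type_: Union[ElementType, ShapedType],
    array_shape: Sequence[int],
    explicit_shape: Optional[Sequence[int]] = None,
) -> ShapedType:
    """Choose the attribute's full type, mirroring the branch in this commit."""
    if isinstance(type_, ShapedType):
        # A shaped type already fixes the shape, so an explicit one is an error
        # (the C++ code throws std::invalid_argument at this point).
        if explicit_shape is not None:
            raise ValueError("Shape can only be specified explicitly "
                             "when the type is not a shaped type.")
        return type_
    # An element type is wrapped in a ranked tensor, as before this change;
    # the shape comes from `shape` if given, else from the input array.
    shape = explicit_shape if explicit_shape is not None else array_shape
    return ShapedType("tensor", tuple(shape), type_)


i32 = ElementType("i32")
print(resolve_attr_type(i32, (4,)))                                # tensor<4xi32>
print(resolve_attr_type(i32, (4,), explicit_shape=(2, 2)))         # tensor<2x2xi32>
print(resolve_attr_type(ShapedType("vector", (2, 2), i32), (4,)))  # vector<2x2xi32>
```

The three calls correspond to the three cases exercised by the new test in `builtin.py`; passing both a shaped type and `explicit_shape` raises `ValueError`.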
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D145053
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/dialects/builtin.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b0c35ffb8a53f..c59a54b6699a7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -624,8 +624,17 @@ class PyDenseElementsAttribute
       }
     }
     if (bulkLoadElementType) {
-      auto shapedType = mlirRankedTensorTypeGet(
-          shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      MlirType shapedType;
+      if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+        if (explicitShape) {
+          throw std::invalid_argument("Shape can only be specified explicitly "
+                                      "when the type is not a shaped type.");
+        }
+        shapedType = *bulkLoadElementType;
+      } else {
+        shapedType = mlirRankedTensorTypeGet(
+            shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      }
       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
           shapedType, rawBufferSize, arrayInfo.ptr);
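The unchanged context in this hunk sizes the raw buffer as element count times bytes per element, regardless of which shaped type was selected. A stdlib-only sketch of that arithmetic, using Python's `array` and `memoryview` in place of the numpy array and pybind11 buffer info the bindings actually receive (a substitution made only to keep the example dependency-free):

```python
from array import array

# Four integers, standing in for np.arange(4, dtype=np.int32) in the test.
# "i" is a C int; its size is platform-dependent (commonly 4 bytes).
values = array("i", range(4))
view = memoryview(values)

size = len(view)          # element count, like arrayInfo.size
itemsize = view.itemsize  # bytes per element, like arrayInfo.itemsize
raw_buffer_size = size * itemsize

# For a contiguous 1-D buffer the product equals the total byte length.
assert raw_buffer_size == view.nbytes
print(size, itemsize, raw_buffer_size)
```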
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 94e29892ba7b1..eab24b5c796b0 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -3,6 +3,7 @@
from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.func as func
+import numpy as np
def run(f):
@@ -221,3 +222,17 @@ def testFuncArgumentAccess():
   # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
   # CHECK: %{{.*}}: f32)
   print(module)
+
+
+# CHECK-LABEL: testDenseElementsAttr
+@run
+def testDenseElementsAttr():
+  with Context(), Location.unknown():
+    values = np.arange(4, dtype=np.int32)
+    i32 = IntegerType.get_signless(32)
+    print(DenseElementsAttr.get(values, type=i32))
+    # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
+    print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
+    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+    print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
+    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>