[Mlir-commits] [mlir] 989b194 - [mlir][Python] Make DenseElementsAttr loading be int size agnostic.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Nov 17 21:55:31 PST 2020
Author: Stella Laurenzo
Date: 2020-11-17T21:50:44-08:00
New Revision: 989b19442905b2a8aa83e1db65c1c5ab1211a27b
URL: https://github.com/llvm/llvm-project/commit/989b19442905b2a8aa83e1db65c1c5ab1211a27b
DIFF: https://github.com/llvm/llvm-project/commit/989b19442905b2a8aa83e1db65c1c5ab1211a27b.diff
LOG: [mlir][Python] Make DenseElementsAttr loading be int size agnostic.
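* I had missed the note about "Standard size" in the Python struct docs: native integer sizes are platform-dependent, and on Windows the 'l'/'L' types are 32-bit. Integer buffers are therefore now dispatched on the buffer's itemsize rather than on the format character alone.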
* This fixes the only failing MLIR-Python test on Windows.
Differential Revision: https://reviews.llvm.org/D91283
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 152f067ea636..7b5e341bc660 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1534,6 +1534,7 @@ class PyDenseElementsAttribute
MlirContext context = contextWrapper->get();
// Switch on the types that can be bulk loaded between the Python and
// MLIR-C APIs.
+ // See: https://docs.python.org/3/library/struct.html#format-characters
if (arrayInfo.format == "f") {
// f32
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
@@ -1548,42 +1549,44 @@ class PyDenseElementsAttribute
contextWrapper->getRef(),
bulkLoad(context, mlirDenseElementsAttrDoubleGet,
mlirF64TypeGet(context), arrayInfo));
- } else if (arrayInfo.format == "i") {
- // i32
- assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt32Get,
- elementType, arrayInfo));
- } else if (arrayInfo.format == "I") {
- // unsigned i32
- assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt32Get,
- elementType, arrayInfo));
- } else if (arrayInfo.format == "l") {
- // i64
- assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt64Get,
- elementType, arrayInfo));
- } else if (arrayInfo.format == "L") {
- // unsigned i64
- assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt64Get,
- elementType, arrayInfo));
+ } else if (isSignedIntegerFormat(arrayInfo.format)) {
+ if (arrayInfo.itemsize == 4) {
+ // i32
+ MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ return PyDenseElementsAttribute(contextWrapper->getRef(),
+ bulkLoad(context,
+ mlirDenseElementsAttrInt32Get,
+ elementType, arrayInfo));
+ } else if (arrayInfo.itemsize == 8) {
+ // i64
+ MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ return PyDenseElementsAttribute(contextWrapper->getRef(),
+ bulkLoad(context,
+ mlirDenseElementsAttrInt64Get,
+ elementType, arrayInfo));
+ }
+ } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
+ if (arrayInfo.itemsize == 4) {
+ // unsigned i32
+ MlirType elementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ return PyDenseElementsAttribute(contextWrapper->getRef(),
+ bulkLoad(context,
+ mlirDenseElementsAttrUInt32Get,
+ elementType, arrayInfo));
+ } else if (arrayInfo.itemsize == 8) {
+ // unsigned i64
+ MlirType elementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ return PyDenseElementsAttribute(contextWrapper->getRef(),
+ bulkLoad(context,
+ mlirDenseElementsAttrUInt64Get,
+ elementType, arrayInfo));
+ }
}
// TODO: Fall back to string-based get.
@@ -1656,7 +1659,23 @@ class PyDenseElementsAttribute
const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
return ctor(shapedType, numElements, contents);
}
-};
+
+  /// Returns the single element code of a Python buffer format string, or
+  /// '\0' if `format` does not describe exactly one scalar element.
+  /// An optional '@' (native) or '=' (native order, standard size) prefix is
+  /// accepted: both keep native byte order, so the raw bytes can be copied
+  /// as-is. The element width is taken from the buffer itemsize, not the code.
+  /// See: https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
+  static char getScalarFormatCode(const std::string &format) {
+    size_t start = 0;
+    if (!format.empty() && (format[0] == '@' || format[0] == '='))
+      start = 1;
+    // Reject multi-element/struct formats (e.g. "ii", "T{...}") and explicit
+    // '<' / '>' / '!' byte orders: loading those as one scalar of `itemsize`
+    // bytes would silently reinterpret the data.
+    if (format.size() != start + 1)
+      return '\0';
+    return format[start];
+  }
+
+  /// Returns true if `format` describes a single unsigned integer element
+  /// ('B', 'H', 'I', 'L' or 'Q'). Callers must use the buffer itemsize for the
+  /// width, since e.g. 'L' is 32-bit on Windows and 64-bit on most Unixes.
+  static bool isUnsignedIntegerFormat(const std::string &format) {
+    char code = getScalarFormatCode(format);
+    return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+           code == 'Q';
+  }
+
+  /// Returns true if `format` describes a single signed integer element
+  /// ('b', 'h', 'i', 'l' or 'q'). Callers must use the buffer itemsize for the
+  /// width, since e.g. 'l' is 32-bit on Windows and 64-bit on most Unixes.
+  static bool isSignedIntegerFormat(const std::string &format) {
+    char code = getScalarFormatCode(format);
+    return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+           code == 'q';
+  }
+}; // namespace
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
/// (and boolean) values. Supports element access.
More information about the Mlir-commits mailing list