[Mlir-commits] [mlir] 989b194 - [mlir][Python] Make DenseElementsAttr loading be int size agnostic.

Stella Laurenzo llvmlistbot at llvm.org
Tue Nov 17 21:55:31 PST 2020


Author: Stella Laurenzo
Date: 2020-11-17T21:50:44-08:00
New Revision: 989b19442905b2a8aa83e1db65c1c5ab1211a27b

URL: https://github.com/llvm/llvm-project/commit/989b19442905b2a8aa83e1db65c1c5ab1211a27b
DIFF: https://github.com/llvm/llvm-project/commit/989b19442905b2a8aa83e1db65c1c5ab1211a27b.diff

LOG: [mlir][Python] Make DenseElementsAttr loading be int size agnostic.

* I had missed the note about "Standard size" in the Python struct docs: in native mode, the integer format codes have platform-dependent sizes. On Windows, the 'l'/'L' types are 32-bit.
* This fixes the only failing MLIR-Python test on Windows.

Differential Revision: https://reviews.llvm.org/D91283

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 152f067ea636..7b5e341bc660 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1534,6 +1534,7 @@ class PyDenseElementsAttribute
     MlirContext context = contextWrapper->get();
     // Switch on the types that can be bulk loaded between the Python and
     // MLIR-C APIs.
+    // See: https://docs.python.org/3/library/struct.html#format-characters
     if (arrayInfo.format == "f") {
       // f32
       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
@@ -1548,42 +1549,44 @@ class PyDenseElementsAttribute
           contextWrapper->getRef(),
           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
                    mlirF64TypeGet(context), arrayInfo));
-    } else if (arrayInfo.format == "i") {
-      // i32
-      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
-                                      : mlirIntegerTypeSignedGet(context, 32);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrInt32Get,
-                                               elementType, arrayInfo));
-    } else if (arrayInfo.format == "I") {
-      // unsigned i32
-      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
-                                      : mlirIntegerTypeUnsignedGet(context, 32);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrUInt32Get,
-                                               elementType, arrayInfo));
-    } else if (arrayInfo.format == "l") {
-      // i64
-      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
-                                      : mlirIntegerTypeSignedGet(context, 64);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrInt64Get,
-                                               elementType, arrayInfo));
-    } else if (arrayInfo.format == "L") {
-      // unsigned i64
-      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
-                                      : mlirIntegerTypeUnsignedGet(context, 64);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrUInt64Get,
-                                               elementType, arrayInfo));
+    } else if (isSignedIntegerFormat(arrayInfo.format)) {
+      if (arrayInfo.itemsize == 4) {
+        // i32
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
+                                        : mlirIntegerTypeSignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // i64
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
+                                        : mlirIntegerTypeSignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt64Get,
+                                                 elementType, arrayInfo));
+      }
+    } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
+      if (arrayInfo.itemsize == 4) {
+        // unsigned i32
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 32)
+                                   : mlirIntegerTypeUnsignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // unsigned i64
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 64)
+                                   : mlirIntegerTypeUnsignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt64Get,
+                                                 elementType, arrayInfo));
+      }
     }
 
     // TODO: Fall back to string-based get.
@@ -1656,7 +1659,23 @@ class PyDenseElementsAttribute
     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
     return ctor(shapedType, numElements, contents);
   }
-};
+
+  static bool isUnsignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    char code = format[0];
+    return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+           code == 'Q';
+  }
+
+  static bool isSignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    char code = format[0];
+    return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+           code == 'q';
+  }
+}; // namespace
 
 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
 /// (and boolean) values. Supports element access.


        


More information about the Mlir-commits mailing list