[Mlir-commits] [mlir] f66cd9e - [mlir] Add Python bindings for DenseResourceElementsAttr. (#66319)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 18:45:33 PDT 2023


Author: Stella Laurenzo
Date: 2023-09-14T18:45:29-07:00
New Revision: f66cd9e9556a53142a26a5c21a72e21f1579217c

URL: https://github.com/llvm/llvm-project/commit/f66cd9e9556a53142a26a5c21a72e21f1579217c
DIFF: https://github.com/llvm/llvm-project/commit/f66cd9e9556a53142a26a5c21a72e21f1579217c.diff

LOG: [mlir] Add Python bindings for DenseResourceElementsAttr. (#66319)

Only construction and type casting are implemented. The method to create
is explicitly named "unsafe" and the documentation calls out what the
caller is responsible for. There really isn't a better way to do this
and retain the power-user feature this represents.

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/array_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 93c4ed5692ef26d..01d1b6008f5e215 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -558,6 +558,23 @@ mlirDenseElementsAttrGetRawData(MlirAttribute attr);
 // Resource blob attributes.
 //===----------------------------------------------------------------------===//
 
+MLIR_CAPI_EXPORTED bool
+mlirAttributeIsADenseResourceElements(MlirAttribute attr);
+
+/// Unlike the typed accessors below, constructs the attribute with a raw
+/// data buffer and no type/alignment checking. Use a more strongly typed
+/// accessor if possible. If dataIsMutable is false, then an immutable
+/// AsmResourceBlob will be created and the passed data contents will be
+/// treated as const.
+/// If the deleter is non-NULL, then it will be called when the data buffer
+/// can no longer be accessed (passing userData to it).
+MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
+    MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
+    size_t dataAlignment, bool dataIsMutable,
+    void (*deleter)(void *userData, const void *data, size_t size,
+                    size_t align),
+    void *userData);
+
 MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const int *elements);
@@ -600,13 +617,6 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType,
                                                 intptr_t numElements,
                                                 const double *elements);
 
-/// Unlike the typed accessors above, constructs the attribute with a raw
-/// data buffer and no type/alignment checking. Use a more strongly typed
-/// accessor if possible.
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet(
-    MlirType shapedType, MlirStringRef name, const void *data,
-    size_t dataLength);
-
 /// Returns the pos-th value (flat contiguous indexing) of a specific type
 /// contained by the given dense resource elements attribute.
 MLIR_CAPI_EXPORTED bool

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 105d2cecf20a193..94fa2527e40891e 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -72,6 +72,32 @@ or 255), then a splat will be created.
     type or if the buffer does not meet expectations.
 )";
 
+static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
+    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
+
+This function does minimal validation or massaging of the data, and it is
+up to the caller to ensure that the buffer meets the characteristics
+implied by the shape.
+
+The backing buffer and any user objects will be retained for the lifetime
+of the resource blob. This is typically bounded to the context but the
+resource can have a shorter lifespan depending on how it is used in
+subsequent processing.
+
+Args:
+  buffer: The array or buffer to convert.
+  name: Name to provide to the resource (may be changed upon collision).
+  type: The explicit ShapedType to construct the attribute with.
+  context: Explicit context, if not from context manager.
+
+Returns:
+  DenseResourceElementsAttr on success.
+
+Raises:
+  ValueError: If the type of the buffer or array cannot be matched to an MLIR
+    type or if the buffer does not meet expectations.
+)";
+
 namespace {
 
 static MlirStringRef toMlirStringRef(const std::string &s) {
@@ -997,6 +1023,82 @@ class PyDenseIntElementsAttribute
   }
 };
 
+class PyDenseResourceElementsAttribute
+    : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction =
+      mlirAttributeIsADenseResourceElements;
+  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static PyDenseResourceElementsAttribute
+  getFromBuffer(py::buffer buffer, std::string name, PyType type,
+                std::optional<size_t> alignment, bool isMutable,
+                DefaultingPyMlirContext contextWrapper) {
+    if (!mlirTypeIsAShaped(type)) {
+      throw std::invalid_argument(
+          "Constructing a DenseResourceElementsAttr requires a ShapedType.");
+    }
+
+    // Do not request any conversions, as we must be sure to use
+    // caller-managed memory.
+    int flags = PyBUF_STRIDES;
+    std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
+    if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
+      throw py::error_already_set();
+    }
+
+    // This scope releaser will only release if we haven't yet transferred
+    // ownership.
+    auto freeBuffer = llvm::make_scope_exit([&]() {
+      if (view)
+        PyBuffer_Release(view.get());
+    });
+
+    if (!PyBuffer_IsContiguous(view.get(), 'A')) {
+      throw std::invalid_argument("Contiguous buffer is required.");
+    }
+
+    // Infer alignment to be the stride of one element if not explicit.
+    size_t inferredAlignment;
+    if (alignment)
+      inferredAlignment = *alignment;
+    else
+      inferredAlignment = view->strides[view->ndim - 1];
+
+    // The userData is a Py_buffer* that the deleter owns.
+    auto deleter = [](void *userData, const void *data, size_t size,
+                      size_t align) {
+      Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
+      PyBuffer_Release(ownedView);
+      delete ownedView;
+    };
+
+    size_t rawBufferSize = view->len;
+    MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
+        type, toMlirStringRef(name), view->buf, rawBufferSize,
+        inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
+    if (mlirAttributeIsNull(attr)) {
+      throw std::invalid_argument(
+          "DenseResourceElementsAttr could not be constructed from the given "
+          "buffer. "
+          "This may mean that the Python buffer layout does not match that "
+          "MLIR expected layout and is a bug.");
+    }
+    view.release();
+    return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get_from_buffer",
+                 PyDenseResourceElementsAttribute::getFromBuffer,
+                 py::arg("array"), py::arg("name"), py::arg("type"),
+                 py::arg("alignment") = py::none(),
+                 py::arg("is_mutable") = false, py::arg("context") = py::none(),
+                 kDenseResourceElementsAttrGetFromBufferDocstring);
+  }
+};
+
 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
@@ -1273,6 +1375,7 @@ void mlir::python::populateIRAttributes(py::module &m) {
   PyGlobals::get().registerTypeCaster(
       mlirDenseIntOrFPElementsAttrGetTypeID(),
       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
+  PyDenseResourceElementsAttribute::bind(m);
 
   PyDictAttribute::bind(m);
   PySymbolRefAttribute::bind(m);

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 84a958d01d2eb14..b3066ee0c28bdc8 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -770,6 +770,30 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
 // Resource blob attributes.
 //===----------------------------------------------------------------------===//
 
+bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) {
+  return llvm::isa<DenseResourceElementsAttr>(unwrap(attr));
+}
+
+MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
+    MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
+    size_t dataAlignment, bool dataIsMutable,
+    void (*deleter)(void *userData, const void *data, size_t size,
+                    size_t align),
+    void *userData) {
+  AsmResourceBlob::DeleterFn cppDeleter = {};
+  if (deleter) {
+    cppDeleter = [deleter, userData](void *data, size_t size, size_t align) {
+      deleter(userData, data, size, align);
+    };
+  }
+  AsmResourceBlob blob(
+      llvm::ArrayRef(static_cast<const char *>(data), dataLength),
+      dataAlignment, std::move(cppDeleter), dataIsMutable);
+  return wrap(
+      DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+                                     unwrap(name), std::move(blob)));
+}
+
 template <typename U, typename T>
 static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
                                       intptr_t numElements, const T *elements) {
@@ -778,139 +802,122 @@ static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
                          llvm::ArrayRef(elements, numElements))));
 }
 
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
+MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const int *elements) {
   return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name,
                                                          numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
+MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const uint8_t *elements) {
   return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
                                                         numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute
-mlirUnmanagedDenseUInt16ResourceElementsAttrGet(MlirType shapedType,
-                                                MlirStringRef name,
-                                                intptr_t numElements,
-                                                const uint16_t *elements) {
+MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
+    MlirType shapedType, MlirStringRef name, intptr_t numElements,
+    const uint16_t *elements) {
   return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
                                                          numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute
-mlirUnmanagedDenseUInt32ResourceElementsAttrGet(MlirType shapedType,
-                                                MlirStringRef name,
-                                                intptr_t numElements,
-                                                const uint32_t *elements) {
+MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
+    MlirType shapedType, MlirStringRef name, intptr_t numElements,
+    const uint32_t *elements) {
   return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
                                                          numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute
-mlirUnmanagedDenseUInt64ResourceElementsAttrGet(MlirType shapedType,
-                                                MlirStringRef name,
-                                                intptr_t numElements,
-                                                const uint64_t *elements) {
+MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
+    MlirType shapedType, MlirStringRef name, intptr_t numElements,
+    const uint64_t *elements) {
   return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
                                                          numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
+MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const int8_t *elements) {
   return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
                                                         numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
+MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const int16_t *elements) {
   return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
                                                          numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
+MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const int32_t *elements) {
   return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
                                                          numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
+MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const int64_t *elements) {
   return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
                                                          numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
+MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
     MlirType shapedType, MlirStringRef name, intptr_t numElements,
     const float *elements) {
   return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name,
                                                         numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute
-mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType,
-                                                MlirStringRef name,
-                                                intptr_t numElements,
-                                                const double *elements) {
+MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet(
+    MlirType shapedType, MlirStringRef name, intptr_t numElements,
+    const double *elements) {
   return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name,
                                                         numElements, elements);
 }
-MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet(
-    MlirType shapedType, MlirStringRef name, const void *data,
-    size_t dataLength) {
-  return wrap(DenseResourceElementsAttr::get(
-      llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
-      UnmanagedAsmResourceBlob::allocateInferAlign(
-          llvm::ArrayRef(static_cast<const char *>(data), dataLength))));
-}
-
 template <typename U, typename T>
 static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
   return (*llvm::cast<U>(unwrap(attr)).tryGetAsArrayRef())[pos];
 }
 
-MLIR_CAPI_EXPORTED bool
-mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr,
+                                               intptr_t pos) {
   return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos);
 }
-MLIR_CAPI_EXPORTED uint8_t
-mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                   intptr_t pos) {
   return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos);
 }
-MLIR_CAPI_EXPORTED uint16_t
-mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                     intptr_t pos) {
   return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr,
                                                                       pos);
 }
-MLIR_CAPI_EXPORTED uint32_t
-mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                     intptr_t pos) {
   return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr,
                                                                       pos);
 }
-MLIR_CAPI_EXPORTED uint64_t
-mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                     intptr_t pos) {
   return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr,
                                                                       pos);
 }
-MLIR_CAPI_EXPORTED int8_t
-mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                 intptr_t pos) {
   return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos);
 }
-MLIR_CAPI_EXPORTED int16_t
-mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                   intptr_t pos) {
   return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos);
 }
-MLIR_CAPI_EXPORTED int32_t
-mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                   intptr_t pos) {
   return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos);
 }
-MLIR_CAPI_EXPORTED int64_t
-mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr,
+                                                   intptr_t pos) {
   return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos);
 }
-MLIR_CAPI_EXPORTED float
-mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr,
+                                                 intptr_t pos) {
   return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos);
 }
-MLIR_CAPI_EXPORTED double
-mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
+double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr,
+                                                   intptr_t pos) {
   return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos);
 }
 

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5d78daa296501f4..5725d05a3e132f7 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -35,6 +35,17 @@ static void registerAllUpstreamDialects(MlirContext ctx) {
   mlirDialectRegistryDestroy(registry);
 }
 
+struct ResourceDeleteUserData {
+  const char *name;
+};
+static struct ResourceDeleteUserData resourceI64BlobUserData = {
+    "resource_i64_blob"};
+static void reportResourceDelete(void *userData, const void *data, size_t size,
+                                 size_t align) {
+  fprintf(stderr, "reportResourceDelete: %s\n",
+          ((struct ResourceDeleteUserData *)userData)->name);
+}
+
 void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
                       MlirLocation location, MlirBlock funcBody) {
   MlirValue iv = mlirBlockGetArgument(loopBody, 0);
@@ -1270,10 +1281,14 @@ int printBuiltinAttributes(MlirContext ctx) {
   MlirAttribute doublesBlob = mlirUnmanagedDenseDoubleResourceElementsAttrGet(
       mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
       mlirStringRefCreateFromCString("resource_f64"), 2, doubles);
-  MlirAttribute blobBlob = mlirUnmanagedDenseBlobResourceElementsAttrGet(
+  MlirAttribute blobBlob = mlirUnmanagedDenseResourceElementsAttrGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
-      mlirStringRefCreateFromCString("resource_i64_blob"), uints64,
-      sizeof(uints64));
+      mlirStringRefCreateFromCString("resource_i64_blob"), /*data=*/uints64,
+      /*dataLength=*/sizeof(uints64),
+      /*dataAlignment=*/_Alignof(uint64_t),
+      /*dataIsMutable=*/false,
+      /*deleter=*/reportResourceDelete,
+      /*userData=*/(void *)&resourceI64BlobUserData);
 
   mlirAttributeDump(uint8Blob);
   mlirAttributeDump(uint16Blob);
@@ -2329,9 +2344,13 @@ int main(void) {
   if (testDialectRegistry())
     return 15;
 
-  mlirContextDestroy(ctx);
-
   testExplicitThreadPools();
   testDiagnostics();
+
+  // CHECK: DESTROY MAIN CONTEXT
+  // CHECK: reportResourceDelete: resource_i64_blob
+  fprintf(stderr, "DESTROY MAIN CONTEXT\n");
+  mlirContextDestroy(ctx);
+
   return 0;
 }

diff  --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 452d860861d783a..9251588a4c48a6e 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -5,6 +5,7 @@
 import gc
 from mlir.ir import *
 import numpy as np
+import weakref
 
 
 def run(f):
@@ -162,7 +163,7 @@ def testGetDenseElementsBF16():
 @run
 def testGetDenseElementsInteger4():
     with Context():
-        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8)
+        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8)
         attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
         # Note: These values don't mean much since just bit-casting. But they
         # shouldn't change.
@@ -417,3 +418,44 @@ def testGetDenseElementsIndex():
         print(arr)
         # CHECK: True
         print(arr.dtype == np.int64)
+
+
+# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr
+@run
+def testGetDenseResourceElementsAttr():
+    def on_delete(_):
+        print("BACKING MEMORY DELETED")
+
+    context = Context()
+    mview = memoryview(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+    ref = weakref.ref(mview, on_delete)
+
+    def test_attribute(context, mview):
+        with context, Location.unknown():
+            element_type = IntegerType.get_signless(32)
+            tensor_type = RankedTensorType.get((2, 3), element_type)
+            resource = DenseResourceElementsAttr.get_from_buffer(
+                mview, "from_py", tensor_type
+            )
+            module = Module.parse("module {}")
+            module.operation.attributes["test.resource"] = resource
+            # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
+            # CHECK: from_py: "0x04000000010000000200000003000000040000000500000006000000"
+            print(module)
+
+            # Verifies type casting.
+            # CHECK: dense_resource<from_py> : tensor<2x3xi32>
+            print(
+                DenseResourceElementsAttr(module.operation.attributes["test.resource"])
+            )
+
+    test_attribute(context, mview)
+    mview = None
+    gc.collect()
+    # CHECK: FREEING CONTEXT
+    print("FREEING CONTEXT")
+    context = None
+    gc.collect()
+    # CHECK: BACKING MEMORY DELETED
+    # CHECK: EXIT FUNCTION
+    print("EXIT FUNCTION")


        


More information about the Mlir-commits mailing list