[Mlir-commits] [mlir] 2a603de - [mlir][Python] Fix conversion of non-zero offset memrefs to np.arrays

Alex Zinenko llvmlistbot at llvm.org
Tue Sep 5 01:03:12 PDT 2023


Author: Felix Schneider
Date: 2023-09-05T08:02:59Z
New Revision: 2a603deec49cb6a27f3a29480ed8a133eef31cee

URL: https://github.com/llvm/llvm-project/commit/2a603deec49cb6a27f3a29480ed8a133eef31cee
DIFF: https://github.com/llvm/llvm-project/commit/2a603deec49cb6a27f3a29480ed8a133eef31cee.diff

LOG: [mlir][Python] Fix conversion of non-zero offset memrefs to np.arrays

Memref descriptors contain an `offset` field that denotes the start of
the content of the memref relative to the `alignedPtr`. This offset is
not considered when converting a memref descriptor to a np.array in the
Python runtime library, essentially treating all memrefs as if they had
an offset of zero. This patch adds, to the memref->numpy conversion
functions, the pointer arithmetic needed to find the actual beginning of
the memref contents.

There is an ongoing discussion about whether the `offset` field is needed
at all in the memref descriptor.
Until that is settled, the Python runtime and CRunnerUtils should
still handle the offset correctly.

Related: https://reviews.llvm.org/D157008

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D158494

Added: 
    

Modified: 
    mlir/python/mlir/runtime/np_to_memref.py
    mlir/test/python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 51433d75ac4fb1f..0a3b411041b2f4d 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -114,13 +114,21 @@ def get_unranked_memref_descriptor(nparray):
     d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
     return d
 
+def move_aligned_ptr_by_offset(aligned_ptr, offset):
+    """Returns a new pointer of `aligned_ptr`'s type, advanced by `offset` elements."""
+    aligned_addr = ctypes.addressof(aligned_ptr.contents)
+    elem_size = ctypes.sizeof(aligned_ptr.contents)
+    shift = offset * elem_size  # `offset` counts elements; convert it to bytes.
+    content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
+    return content_ptr
 
 def unranked_memref_to_numpy(unranked_memref, np_dtype):
     """Converts unranked memrefs to numpy arrays."""
     ctp = as_ctype(np_dtype)
     descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
     val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
-    np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+    content_ptr = move_aligned_ptr_by_offset(val[0].aligned, val[0].offset)
+    np_arr = np.ctypeslib.as_array(content_ptr, shape=val[0].shape)
     strided_arr = np.lib.stride_tricks.as_strided(
         np_arr,
         np.ctypeslib.as_array(val[0].shape),
@@ -131,8 +139,9 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
 
 def ranked_memref_to_numpy(ranked_memref):
     """Converts ranked memrefs to numpy arrays."""
+    content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
     np_arr = np.ctypeslib.as_array(
-        ranked_memref[0].aligned, shape=ranked_memref[0].shape
+        content_ptr, shape=ranked_memref[0].shape
     )
     strided_arr = np.lib.stride_tricks.as_strided(
         np_arr,

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 50d6e82348a9f6f..e8b47007a8907dd 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -245,6 +245,87 @@ def callback(a):
 run(testRankedMemRefCallback)
 
 
+# Test callback with a ranked memref with non-zero offset.
+# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
+def testRankedMemRefWithOffsetCallback():
+    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
+    @ctypes.CFUNCTYPE(
+        None,
+        ctypes.POINTER(
+            make_nd_memref_descriptor(1, np.ctypeslib.as_ctypes_type(np.float32))
+        ),
+    )
+    def callback(a):
+        arr = ranked_memref_to_numpy(a)
+        log("Inside Callback: ")
+        log(arr)
+
+    with Context():
+        # The module reinterprets the argument as a 2-element view at offset 3 and passes it to the callback.
+        module = Module.parse(
+            r"""
+func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
+  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
+  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
+  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
+  return
+}
+func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        inp_arr = np.array([0, 0, 0, 1, 2], np.float32)
+        # CHECK: Inside Callback:
+        # CHECK{LITERAL}: [1. 2.]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
+        )
+
+
+run(testRankedMemRefWithOffsetCallback)
+
+
+# Test callback with an unranked memref with non-zero offset.
+# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
+def testUnrankedMemRefWithOffsetCallback():
+    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
+    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+    def callback(a):
+        arr = unranked_memref_to_numpy(a, np.float32)
+        log("Inside callback: ")
+        log(arr)
+
+    with Context():
+        # The module reinterprets the argument as a 2-element view at offset 3, casts it to an
+        # unranked memref and passes it to the callback.
+        module = Module.parse(
+            r"""
+func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
+    %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
+    %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
+    %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
+    call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
+    return
+}
+func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        inp_arr = np.array([1, 2, 3, 4, 5], np.float32)
+        # CHECK: Inside callback:
+        # CHECK{LITERAL}: [4. 5.]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
+        )
+
+run(testUnrankedMemRefWithOffsetCallback)
+
+
 #  Test addition of two memrefs.
 # CHECK-LABEL: TEST: testMemrefAdd
 def testMemrefAdd():


        


More information about the Mlir-commits mailing list