[Mlir-commits] [mlir] 102fd1c - Add support for numpy arrays to memref conversions.

Mehdi Amini llvmlistbot at llvm.org
Thu Apr 15 16:41:38 PDT 2021


Author: Prashant Kumar
Date: 2021-04-15T23:41:26Z
New Revision: 102fd1cb8b40dd6096850fbb9105ffb98aa31824

URL: https://github.com/llvm/llvm-project/commit/102fd1cb8b40dd6096850fbb9105ffb98aa31824
DIFF: https://github.com/llvm/llvm-project/commit/102fd1cb8b40dd6096850fbb9105ffb98aa31824.diff

LOG: Add support for numpy arrays to memref conversions.

This makes it possible to pass NumPy arrays as the corresponding
memref arguments.

Reviewed By: mehdi_amini, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D100077

Added: 
    mlir/lib/Bindings/Python/mlir/runtime/__init__.py
    mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py

Modified: 
    mlir/test/Bindings/Python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/mlir/runtime/__init__.py b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
new file mode 100644
index 000000000000..8a28fd935a40
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
@@ -0,0 +1 @@
+from .np_to_memref import *

diff  --git a/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
new file mode 100644
index 000000000000..43ef9543528c
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
@@ -0,0 +1,119 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains functions to convert between MemRefs and NumPy arrays.
+
+import numpy as np
+import ctypes
+
+
+def make_nd_memref_descriptor(rank, dtype):
+    class MemRefDescriptor(ctypes.Structure):
+        """
+        Build an empty descriptor for the given rank/dtype, where rank>0.
+        """
+
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+            ("shape", ctypes.c_longlong * rank),
+            ("strides", ctypes.c_longlong * rank),
+        ]
+
+    return MemRefDescriptor
+
+
+def make_zero_d_memref_descriptor(dtype):
+    class MemRefDescriptor(ctypes.Structure):
+        """
+        Build an empty descriptor for the given dtype, where rank=0.
+        """
+
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+        ]
+
+    return MemRefDescriptor
+
+
+class UnrankedMemRefDescriptor(ctypes.Structure):
+    """ Creates a ctype struct for memref descriptor"""
+
+    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+
+
+def get_ranked_memref_descriptor(nparray):
+    """
+    Return a ranked memref descriptor for the given numpy array.
+    """
+    if nparray.ndim == 0:
+        x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))()
+        x.allocated = nparray.ctypes.data
+        x.aligned = nparray.ctypes.data_as(
+            ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
+        )
+        x.offset = ctypes.c_longlong(0)
+        return x
+
+    x = make_nd_memref_descriptor(
+        nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype)
+    )()
+    x.allocated = nparray.ctypes.data
+    x.aligned = nparray.ctypes.data_as(
+        ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
+    )
+    x.offset = ctypes.c_longlong(0)
+    x.shape = nparray.ctypes.shape
+
+    # Numpy uses byte quantities to express strides, MLIR OTOH uses the
+    # torch abstraction which specifies strides in terms of elements.
+    strides_ctype_t = ctypes.c_longlong * nparray.ndim
+    x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
+    return x
+
+
+def get_unranked_memref_descriptor(nparray):
+    """
+    Return a generic/unranked memref descriptor for the given numpy array.
+    """
+    d = UnrankedMemRefDescriptor()
+    d.rank = nparray.ndim
+    x = get_ranked_memref_descriptor(nparray)
+    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+    return d
+
+
+def unranked_memref_to_numpy(unranked_memref, np_dtype):
+    """
+    Converts unranked memrefs to numpy arrays.
+    """
+    descriptor = make_nd_memref_descriptor(
+        unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype)
+    )
+    val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+    np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+    strided_arr = np.lib.stride_tricks.as_strided(
+        np_arr,
+        np.ctypeslib.as_array(val[0].shape),
+        np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
+    )
+    return strided_arr
+
+
+def ranked_memref_to_numpy(ranked_memref):
+    """
+    Converts ranked memrefs to numpy arrays.
+    """
+    np_arr = np.ctypeslib.as_array(
+        ranked_memref[0].aligned, shape=ranked_memref[0].shape
+    )
+    strided_arr = np.lib.stride_tricks.as_strided(
+        np_arr,
+        np.ctypeslib.as_array(ranked_memref[0].shape),
+        np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
+    )
+    return strided_arr

diff  --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
index 9ef4dceea2a8..72a6efe22ca2 100644
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.execution_engine import *
+from mlir.runtime import *
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -131,3 +132,179 @@ def callback(a, b):
     log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))
 
 run(testBasicCallback)
+
+# Test callback with an unranked memref
+# CHECK-LABEL: TEST: testUnrankedMemRefCallback
+def testUnrankedMemRefCallback():
+    """Pass numpy arrays to compiled MLIR code as unranked memrefs and print
+    them from a Python callback that the compiled code calls back into."""
+    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
+    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+    def callback(a):
+        # The unranked descriptor does not record the element type, so the
+        # dtype matching memref<*xf32> is supplied explicitly.
+        arr = unranked_memref_to_numpy(a, np.float32)
+        log("Inside callback: ")
+        log(arr)
+
+    with Context():
+        # The module just forwards to a runtime function known as "some_callback_into_python".
+        module = Module.parse(
+            r"""
+func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        # Register the Python callback under the name of the private function
+        # declared in the module.
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        # NOTE(review): invoke receives a pointer to a pointer to each
+        # descriptor. Presumably invoke takes every argument by pointer and
+        # the C interface takes the descriptor by pointer; confirm against
+        # ExecutionEngine.invoke.
+        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
+        # CHECK: Inside callback: 
+        # CHECK{LITERAL}: [[1. 2.]
+        # CHECK{LITERAL}:  [3. 4.]]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
+        )
+        # A (3, 4) view of a 3-element float32 array with byte strides (4, 0):
+        # row i repeats element i in every column. This exercises a
+        # non-contiguous, zero-stride input.
+        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
+        strided_arr = np.lib.stride_tricks.as_strided(
+            inp_arr_1, strides=(4, 0), shape=(3, 4)
+        )
+        # CHECK: Inside callback: 
+        # CHECK{LITERAL}: [[5. 5. 5. 5.]
+        # CHECK{LITERAL}:  [6. 6. 6. 6.]
+        # CHECK{LITERAL}:  [7. 7. 7. 7.]]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(
+                ctypes.pointer(get_unranked_memref_descriptor(strided_arr))
+            ),
+        )
+
+run(testUnrankedMemRefCallback)
+
+# Test callback with a ranked memref.
+# CHECK-LABEL: TEST: testRankedMemRefCallback
+def testRankedMemRefCallback():
+    """Pass a 2x2 float32 numpy array to compiled MLIR code as a ranked memref
+    and print it from a Python callback that the compiled code calls back
+    into."""
+    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
+    # The descriptor type must match memref<2x2xf32>: rank 2, float32 elements.
+    @ctypes.CFUNCTYPE(
+        None,
+        ctypes.POINTER(
+            make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
+        ),
+    )
+    def callback(a):
+        arr = ranked_memref_to_numpy(a)
+        log("Inside Callback: ")
+        log(arr)
+
+    with Context():
+        # The module just forwards to a runtime function known as "some_callback_into_python".
+        module = Module.parse(
+            r"""
+func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        # Register the Python callback under the name of the private function
+        # declared in the module.
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
+        # CHECK: Inside Callback:
+        # CHECK{LITERAL}: [[1. 5.]
+        # CHECK{LITERAL}:  [6. 7.]]
+        execution_engine.invoke(
+            "callback_memref", ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr)))
+        )
+
+run(testRankedMemRefCallback)
+
+#  Test adding a 0-d memref to a 1-D memref.
+# CHECK-LABEL: TEST: testMemrefAdd
+def testMemrefAdd():
+    """Add a 1-element memref and a 0-d memref in compiled code, storing the
+    sum into a third (1-element) memref."""
+    with Context():
+        module = Module.parse(
+            """
+      module  {
+      func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
+        %0 = constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xf32>
+        %2 = memref.load %arg1[] : memref<f32>
+        %3 = addf %1, %2 : f32
+        memref.store %3, %arg2[%0] : memref<1xf32>
+        return
+      }
+     } """
+        )
+        # Python names are shifted by one relative to the MLIR ones:
+        # arg1 -> %arg0, arg2 -> %arg1, res -> %arg2 (output).
+        arg1 = np.array([32.5]).astype(np.float32)
+        # 0-d array, so get_ranked_memref_descriptor takes its rank-0 path to
+        # match memref<f32>.
+        arg2 = np.array(6).astype(np.float32)
+        res = np.array([0]).astype(np.float32)
+
+        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+        )
+        # CHECK: [32.5] + 6.0 = [38.5]
+        log("{0} + {1} = {2}".format(arg1, arg2, res))
+
+run(testMemrefAdd)
+
+#  Test element-wise addition of two 2-D memrefs, one of them with dynamic sizes.
+# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
+def testDynamicMemrefAdd2D():
+    """Element-wise add two 2x2 float32 arrays in compiled code, where the
+    second operand is passed as a dynamically-sized memref<?x?xf32>, and
+    verify the result against numpy."""
+    with Context():
+        module = Module.parse(
+      """
+      module  {
+        func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
+          %c0 = constant 0 : index
+          %c2 = constant 2 : index
+          %c1 = constant 1 : index
+          br ^bb1(%c0 : index)
+        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
+          %1 = cmpi slt, %0, %c2 : index
+          cond_br %1, ^bb2, ^bb6
+        ^bb2:  // pred: ^bb1
+          %c0_0 = constant 0 : index
+          %c2_1 = constant 2 : index
+          %c1_2 = constant 1 : index
+          br ^bb3(%c0_0 : index)
+        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
+          %3 = cmpi slt, %2, %c2_1 : index
+          cond_br %3, ^bb4, ^bb5
+        ^bb4:  // pred: ^bb3
+          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
+          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
+          %6 = addf %4, %5 : f32
+          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
+          %7 = addi %2, %c1_2 : index
+          br ^bb3(%7 : index)
+        ^bb5:  // pred: ^bb3
+          %8 = addi %0, %c1 : index
+          br ^bb1(%8 : index)
+        ^bb6:  // pred: ^bb1
+          return
+        }
+      }
+        """
+        )
+        # Python names are shifted by one relative to the MLIR ones:
+        # arg1 -> %arg0, arg2 -> %arg1 (dynamic sizes), res -> %arg2 (output).
+        arg1 = np.random.randn(2,2).astype(np.float32)
+        arg2 = np.random.randn(2,2).astype(np.float32)
+        # Random initial contents; every element is overwritten by the kernel.
+        res = np.random.randn(2,2).astype(np.float32)
+
+        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+        )
+        # CHECK: True
+        log(np.allclose(arg1+arg2, res))
+
+run(testDynamicMemrefAdd2D)


        


More information about the Mlir-commits mailing list