[Mlir-commits] [mlir] 102fd1c - Add support for numpy arrays to memref conversions.
Mehdi Amini
llvmlistbot at llvm.org
Thu Apr 15 16:41:38 PDT 2021
Author: Prashant Kumar
Date: 2021-04-15T23:41:26Z
New Revision: 102fd1cb8b40dd6096850fbb9105ffb98aa31824
URL: https://github.com/llvm/llvm-project/commit/102fd1cb8b40dd6096850fbb9105ffb98aa31824
DIFF: https://github.com/llvm/llvm-project/commit/102fd1cb8b40dd6096850fbb9105ffb98aa31824.diff
LOG: Add support for numpy arrays to memref conversions.
This offers the ability to pass numpy arrays to the corresponding
memref argument.
Reviewed By: mehdi_amini, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D100077
Added:
mlir/lib/Bindings/Python/mlir/runtime/__init__.py
mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
Modified:
mlir/test/Bindings/Python/execution_engine.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/mlir/runtime/__init__.py b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
new file mode 100644
index 000000000000..8a28fd935a40
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
@@ -0,0 +1 @@
+from .np_to_memref import *
diff --git a/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
new file mode 100644
index 000000000000..43ef9543528c
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
@@ -0,0 +1,119 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains functions to convert NumPy arrays to MemRef descriptors and vice-versa.
+
+import numpy as np
+import ctypes
+
+
def make_nd_memref_descriptor(rank, dtype):
    """Create a ctypes Structure class for a ranked memref descriptor.

    The generated class has the fields, in order: ``allocated`` (an address
    stored as a 64-bit integer), ``aligned`` (a pointer to ``dtype``),
    ``offset``, and then ``rank`` sizes (``shape``) followed by ``rank``
    strides (``strides``), all as ``ctypes.c_longlong``. Intended for
    rank > 0; use make_zero_d_memref_descriptor for rank 0.

    Args:
        rank: number of dimensions of the memref.
        dtype: ctypes type of one element (e.g. ``ctypes.c_float``).

    Returns:
        A new ``ctypes.Structure`` subclass (not an instance).
    """
    # Both per-dimension arrays share the same ctypes array type.
    index_array_t = ctypes.c_longlong * rank

    class MemRefDescriptor(ctypes.Structure):
        """Empty ranked memref descriptor; fill in the fields after creation."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
            ("shape", index_array_t),
            ("strides", index_array_t),
        ]

    return MemRefDescriptor
+
+
def make_zero_d_memref_descriptor(dtype):
    """Create a ctypes Structure class for a rank-0 (scalar) memref descriptor.

    A rank-0 descriptor carries no shape or strides, only ``allocated`` (an
    address stored as a 64-bit integer), ``aligned`` (a pointer to ``dtype``)
    and ``offset``.

    Args:
        dtype: ctypes type of the single element (e.g. ``ctypes.c_float``).

    Returns:
        A new ``ctypes.Structure`` subclass (not an instance).
    """
    element_ptr_t = ctypes.POINTER(dtype)

    class MemRefDescriptor(ctypes.Structure):
        """Empty rank-0 memref descriptor; fill in the fields after creation."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", element_ptr_t),
            ("offset", ctypes.c_longlong),
        ]

    return MemRefDescriptor
+
+
class UnrankedMemRefDescriptor(ctypes.Structure):
    """ctypes struct for an unranked memref.

    It records the ``rank`` and an untyped pointer (``descriptor``) to a
    ranked descriptor of that rank, such as one produced by
    make_nd_memref_descriptor / make_zero_d_memref_descriptor.
    """

    _fields_ = [
        ("rank", ctypes.c_longlong),
        ("descriptor", ctypes.c_void_p),
    ]
+
+
def get_ranked_memref_descriptor(nparray):
    """
    Return a ranked memref descriptor for the given numpy array.

    The descriptor aliases the array's buffer; no data is copied. Both
    ``allocated`` and ``aligned`` are set to the array's data address and
    ``offset`` is 0. For ``nparray.ndim > 0`` the ``shape`` and ``strides``
    fields are filled in as well. Strides are converted from NumPy's byte
    units to element units.

    The descriptor is not guaranteed to keep ``nparray`` alive, so keep a
    reference to the array for as long as the descriptor is in use.

    Raises:
        ValueError: if a dimension with more than one element has a byte
            stride that is not a whole multiple of ``nparray.itemsize``
            (e.g. a field view of a packed structured array). Such a layout
            cannot be expressed with element strides.
    """
    ctype = np.ctypeslib.as_ctypes_type(nparray.dtype)
    rank = nparray.ndim

    if rank == 0:
        desc = make_zero_d_memref_descriptor(ctype)()
    else:
        desc = make_nd_memref_descriptor(rank, ctype)()

    desc.allocated = nparray.ctypes.data
    desc.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctype))
    desc.offset = 0
    if rank == 0:
        return desc

    index_array_t = ctypes.c_longlong * rank
    # Build the shape explicitly as c_longlong. ``nparray.ctypes.shape`` is a
    # c_intp array, which is not the same ctypes type as c_longlong on every
    # platform, and ctypes rejects assigning mismatched array types.
    desc.shape = index_array_t(*nparray.shape)

    # Numpy uses byte quantities to express strides, while the memref
    # descriptor specifies strides in terms of elements.
    itemsize = nparray.itemsize
    element_strides = []
    for extent, byte_stride in zip(nparray.shape, nparray.strides):
        quotient, remainder = divmod(byte_stride, itemsize)
        # A dimension of extent <= 1 never uses its stride to address memory,
        # so an inexact stride there is harmless.
        if remainder and extent > 1:
            raise ValueError(
                "stride of {0} bytes is not a multiple of the item size {1}; "
                "cannot express it as a memref stride".format(byte_stride, itemsize)
            )
        element_strides.append(quotient)
    desc.strides = index_array_t(*element_strides)
    return desc
+
+
def get_unranked_memref_descriptor(nparray):
    """
    Wrap ``nparray`` in an unranked memref descriptor.

    A ranked descriptor is built with get_ranked_memref_descriptor. Its
    address is stored, type-erased to ``void*``, in the ``descriptor`` field,
    and ``rank`` is set to ``nparray.ndim``.
    """
    ranked = get_ranked_memref_descriptor(nparray)
    unranked = UnrankedMemRefDescriptor()
    unranked.rank = nparray.ndim
    # NOTE(review): the ranked descriptor is only referenced through the cast
    # result stored in this field; this relies on ctypes keeping it alive via
    # the cast -- confirm if the lifetime ever looks suspicious.
    unranked.descriptor = ctypes.cast(ctypes.pointer(ranked), ctypes.c_void_p)
    return unranked
+
+
def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """
    Converts unranked memrefs to numpy arrays.

    Args:
        unranked_memref: ctypes pointer to an ``UnrankedMemRefDescriptor``,
            e.g. the argument received by a ``ctypes.CFUNCTYPE`` callback.
        np_dtype: element dtype of the memref. The unranked descriptor only
            records the rank, so the caller has to supply the element type.

    Returns:
        A NumPy array that views (does not copy) the memref's data. It
        starts ``offset`` elements past the aligned pointer and uses the
        descriptor's shape and strides.

    Raises:
        ValueError: if the ranked descriptor's aligned pointer is NULL.
    """
    ctype = np.ctypeslib.as_ctypes_type(np_dtype)
    unranked = unranked_memref[0]
    descriptor_t = make_nd_memref_descriptor(unranked.rank, ctype)
    desc = ctypes.cast(unranked.descriptor, ctypes.POINTER(descriptor_t))[0]

    itemsize = ctypes.sizeof(ctype)
    base_address = ctypes.cast(desc.aligned, ctypes.c_void_p).value
    if base_address is None:
        raise ValueError("memref descriptor has a NULL aligned pointer")
    # The first element lives ``offset`` elements past the aligned pointer.
    # Ignoring it returns the wrong data for memrefs with a non-zero offset.
    first_element = ctypes.cast(
        base_address + desc.offset * itemsize, ctypes.POINTER(ctype)
    )

    shape = tuple(desc.shape)
    # Descriptor strides are in elements; NumPy strides are in bytes.
    byte_strides = tuple(stride * itemsize for stride in desc.strides)
    dense_view = np.ctypeslib.as_array(first_element, shape=shape)
    return np.lib.stride_tricks.as_strided(dense_view, shape, byte_strides)
+
+
def ranked_memref_to_numpy(ranked_memref):
    """
    Converts ranked memrefs to numpy arrays.

    Args:
        ranked_memref: ctypes pointer to a ranked descriptor instance (a
            class with ``aligned``, ``offset``, ``shape`` and ``strides``
            fields, as built by make_nd_memref_descriptor).

    Returns:
        A NumPy array that views (does not copy) the memref's data. It
        starts ``offset`` elements past the aligned pointer and uses the
        descriptor's shape and strides. The element type is taken from the
        ``aligned`` pointer's ctypes type.

    Raises:
        ValueError: if the descriptor's aligned pointer is NULL.
    """
    desc = ranked_memref[0]
    aligned = desc.aligned
    element_t = aligned._type_
    itemsize = ctypes.sizeof(element_t)

    base_address = ctypes.cast(aligned, ctypes.c_void_p).value
    if base_address is None:
        raise ValueError("memref descriptor has a NULL aligned pointer")
    # The first element lives ``offset`` elements past the aligned pointer.
    # Ignoring it returns the wrong data for memrefs with a non-zero offset.
    first_element = ctypes.cast(
        base_address + desc.offset * itemsize, ctypes.POINTER(element_t)
    )

    shape = tuple(desc.shape)
    # Descriptor strides are in elements; NumPy strides are in bytes.
    byte_strides = tuple(stride * itemsize for stride in desc.strides)
    dense_view = np.ctypeslib.as_array(first_element, shape=shape)
    return np.lib.stride_tricks.as_strided(dense_view, shape, byte_strides)
diff --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
index 9ef4dceea2a8..72a6efe22ca2 100644
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -4,6 +4,7 @@
from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
+from mlir.runtime import *
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
@@ -131,3 +132,179 @@ def callback(a, b):
log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))
run(testBasicCallback)
+
+# Test callback with an unranked memref
+# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass NumPy arrays to MLIR as unranked memrefs and read them back.

    ``@callback_memref`` forwards its ``memref<*xf32>`` argument to the
    Python ``callback`` registered as "some_callback_into_python". The
    callback converts the argument back to a NumPy array and logs it. The
    FileCheck directives in this function must stay in the same order as the
    invocations that produce the output they match.
    """
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    # CFUNCTYPE(None, ...) exposes it as a C function returning void that
    # takes a pointer to an UnrankedMemRefDescriptor.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        # The unranked descriptor does not record the element type, so the
        # dtype is supplied explicitly.
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
 call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
 return
}
func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Bind the MLIR-side symbol "some_callback_into_python" to the Python
        # callback defined above.
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        # NOTE(review): two levels of ctypes.pointer. Presumably invoke()
        # takes a pointer to each argument, and the C interface receives the
        # memref as a pointer to its descriptor; confirm against the
        # ExecutionEngine documentation.
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
        )
        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
        # Byte strides (4, 0) build a non-contiguous 3x4 view of inp_arr_1.
        # Each row starts one float32 further into the buffer, and the zero
        # stride repeats that element across the row.
        strided_arr = np.lib.stride_tricks.as_strided(
            inp_arr_1, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(
                ctypes.pointer(get_unranked_memref_descriptor(strided_arr))
            ),
        )

run(testUnrankedMemRefCallback)
+
+# Test callback with a ranked memref.
+# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a 2x2 float32 NumPy array to MLIR as a ranked memref and back.

    ``@callback_memref`` forwards its ``memref<2x2xf32>`` argument to the
    Python ``callback`` registered as "some_callback_into_python". The
    callback converts the argument to a NumPy array and logs it. The
    FileCheck directives below must stay in the same order as the output
    they match.
    """
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    # The argument type is a pointer to the rank-2 float32 descriptor class;
    # ranked_memref_to_numpy reads the aligned pointer, shape and strides
    # from it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
        ),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
 call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
 return
}
func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Bind the MLIR-side symbol "some_callback_into_python" to the Python
        # callback defined above.
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        # NOTE(review): two levels of ctypes.pointer. Presumably invoke()
        # takes a pointer to each argument, and the C interface receives the
        # memref as a pointer to its descriptor; confirm.
        execution_engine.invoke(
            "callback_memref", ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr)))
        )

run(testRankedMemRefCallback)
+
+# Test adding a 1-D memref and a 0-D memref, storing the sum into a third memref.
+# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a one-element memref and a rank-0 memref via the ExecutionEngine.

    ``@main`` loads ``%arg0[0]`` and the scalar ``%arg1[]``, adds them and
    stores the sum to ``%arg2[0]``. The Python arrays ``arg1``, ``arg2`` and
    ``res`` are passed as ``%arg0``, ``%arg1`` and ``%arg2`` respectively.
    The sum is observed by logging ``res`` after the call.
    """
    with Context():
        module = Module.parse(
            """
 module {
 func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
 %0 = constant 0 : index
 %1 = memref.load %arg0[%0] : memref<1xf32>
 %2 = memref.load %arg1[] : memref<f32>
 %3 = addf %1, %2 : f32
 memref.store %3, %arg2[%0] : memref<1xf32>
 return
 }
 } """
        )
        arg1 = np.array([32.5]).astype(np.float32)
        # A 0-d array, matching the rank-0 memref<f32> parameter.
        arg2 = np.array(6).astype(np.float32)
        # Output buffer. The descriptor aliases its data, so the kernel's
        # store shows up in res.
        res = np.array([0]).astype(np.float32)

        # NOTE(review): two levels of ctypes.pointer. Presumably invoke()
        # takes a pointer to each argument, and the C interface receives each
        # memref as a pointer to its descriptor; confirm.
        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
        )
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(arg1, arg2, res))

run(testMemrefAdd)
+
+# Test element-wise addition of two 2-D memrefs, one of which has dynamic sizes.
+# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Element-wise 2x2 addition where one operand is ``memref<?x?xf32>``.

    ``@memref_add_2d`` is written as an explicit double loop (blocks ^bb1 to
    ^bb6) over i, j in [0, 2). It stores ``%arg0[i, j] + %arg1[i, j]`` into
    ``%arg2[i, j]``. The Python arrays ``arg1``, ``arg2`` and ``res`` are
    passed as ``%arg0``, ``%arg1`` and ``%arg2``. The inputs are unseeded
    random values; the logged ``np.allclose`` result does not depend on which
    values are drawn.
    """
    with Context():
        module = Module.parse(
            """
 module {
 func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
 %c0 = constant 0 : index
 %c2 = constant 2 : index
 %c1 = constant 1 : index
 br ^bb1(%c0 : index)
 ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
 %1 = cmpi slt, %0, %c2 : index
 cond_br %1, ^bb2, ^bb6
 ^bb2: // pred: ^bb1
 %c0_0 = constant 0 : index
 %c2_1 = constant 2 : index
 %c1_2 = constant 1 : index
 br ^bb3(%c0_0 : index)
 ^bb3(%2: index): // 2 preds: ^bb2, ^bb4
 %3 = cmpi slt, %2, %c2_1 : index
 cond_br %3, ^bb4, ^bb5
 ^bb4: // pred: ^bb3
 %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
 %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
 %6 = addf %4, %5 : f32
 memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
 %7 = addi %2, %c1_2 : index
 br ^bb3(%7 : index)
 ^bb5: // pred: ^bb3
 %8 = addi %0, %c1 : index
 br ^bb1(%8 : index)
 ^bb6: // pred: ^bb1
 return
 }
 }
 """
        )
        arg1 = np.random.randn(2,2).astype(np.float32)
        arg2 = np.random.randn(2,2).astype(np.float32)
        # Initial contents are irrelevant: the loop stores to every element.
        res = np.random.randn(2,2).astype(np.float32)

        # NOTE(review): arg2 feeds the memref<?x?xf32> parameter through an
        # ordinary rank-2 descriptor. Presumably its runtime sizes come from
        # the descriptor's shape field; confirm.
        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
        )
        # CHECK: True
        log(np.allclose(arg1+arg2, res))

run(testDynamicMemrefAdd2D)
More information about the Mlir-commits
mailing list