[Mlir-commits] [mlir] d668218 - [mlir][python][ctypes] fix ctype python binding complication for complex

Aart Bik llvmlistbot at llvm.org
Wed Jun 1 10:15:33 PDT 2022


Author: Aart Bik
Date: 2022-06-01T10:15:24-07:00
New Revision: d668218946b1b7e0f3f376c59fe818bc0579d6fd

URL: https://github.com/llvm/llvm-project/commit/d668218946b1b7e0f3f376c59fe818bc0579d6fd
DIFF: https://github.com/llvm/llvm-project/commit/d668218946b1b7e0f3f376c59fe818bc0579d6fd.diff

LOG: [mlir][python][ctypes] fix ctype python binding complication for complex

There is no direct ctypes type for MLIR's complex types (and thus none for
np.complex128 and np.complex64) yet, causing the MLIR Python binding methods
for memrefs to crash. This revision fixes this by passing complex arrays
as pairs of floats (ctypes structures with "real" and "imag" fields),
correcting at the boundaries for the proper complex view.

NOTE: some of these changes (indentation 4 -> 2 spaces) were forced by the new "linting".

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D126422

Added: 
    

Modified: 
    mlir/python/mlir/runtime/np_to_memref.py
    mlir/test/python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 43ef9543528c3..de5b8d6f70d8b 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -8,112 +8,121 @@
 import ctypes
 
 
+class C128(ctypes.Structure):
+  """A ctype representation for MLIR's Double Complex."""
+  _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
+
+
+class C64(ctypes.Structure):
+  """A ctype representation for MLIR's Float Complex."""
+  _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
+
+
+def as_ctype(dtp):
+  """Converts dtype to ctype."""
+  if dtp is np.dtype(np.complex128):
+    return C128
+  if dtp is np.dtype(np.complex64):
+    return C64
+  return np.ctypeslib.as_ctypes_type(dtp)
+
+
 def make_nd_memref_descriptor(rank, dtype):
-    class MemRefDescriptor(ctypes.Structure):
-        """
-        Build an empty descriptor for the given rank/dtype, where rank>0.
-        """
 
-        _fields_ = [
-            ("allocated", ctypes.c_longlong),
-            ("aligned", ctypes.POINTER(dtype)),
-            ("offset", ctypes.c_longlong),
-            ("shape", ctypes.c_longlong * rank),
-            ("strides", ctypes.c_longlong * rank),
-        ]
+  class MemRefDescriptor(ctypes.Structure):
+    """Builds an empty descriptor for the given rank/dtype, where rank>0."""
 
-    return MemRefDescriptor
+    _fields_ = [
+        ("allocated", ctypes.c_longlong),
+        ("aligned", ctypes.POINTER(dtype)),
+        ("offset", ctypes.c_longlong),
+        ("shape", ctypes.c_longlong * rank),
+        ("strides", ctypes.c_longlong * rank),
+    ]
+
+  return MemRefDescriptor
 
 
 def make_zero_d_memref_descriptor(dtype):
-    class MemRefDescriptor(ctypes.Structure):
-        """
-        Build an empty descriptor for the given dtype, where rank=0.
-        """
 
-        _fields_ = [
-            ("allocated", ctypes.c_longlong),
-            ("aligned", ctypes.POINTER(dtype)),
-            ("offset", ctypes.c_longlong),
-        ]
+  class MemRefDescriptor(ctypes.Structure):
+    """Builds an empty descriptor for the given dtype, where rank=0."""
 
-    return MemRefDescriptor
+    _fields_ = [
+        ("allocated", ctypes.c_longlong),
+        ("aligned", ctypes.POINTER(dtype)),
+        ("offset", ctypes.c_longlong),
+    ]
 
+  return MemRefDescriptor
 
-class UnrankedMemRefDescriptor(ctypes.Structure):
-    """ Creates a ctype struct for memref descriptor"""
 
-    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+class UnrankedMemRefDescriptor(ctypes.Structure):
+  """Creates a ctype struct for memref descriptor"""
+  _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
 
 
 def get_ranked_memref_descriptor(nparray):
-    """
-    Return a ranked memref descriptor for the given numpy array.
-    """
-    if nparray.ndim == 0:
-        x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))()
-        x.allocated = nparray.ctypes.data
-        x.aligned = nparray.ctypes.data_as(
-            ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
-        )
-        x.offset = ctypes.c_longlong(0)
-        return x
-
-    x = make_nd_memref_descriptor(
-        nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype)
-    )()
+  """Returns a ranked memref descriptor for the given numpy array."""
+  ctp = as_ctype(nparray.dtype)
+  if nparray.ndim == 0:
+    x = make_zero_d_memref_descriptor(ctp)()
     x.allocated = nparray.ctypes.data
-    x.aligned = nparray.ctypes.data_as(
-        ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
-    )
+    x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
     x.offset = ctypes.c_longlong(0)
-    x.shape = nparray.ctypes.shape
-
-    # Numpy uses byte quantities to express strides, MLIR OTOH uses the
-    # torch abstraction which specifies strides in terms of elements.
-    strides_ctype_t = ctypes.c_longlong * nparray.ndim
-    x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
     return x
 
+  x = make_nd_memref_descriptor(nparray.ndim, ctp)()
+  x.allocated = nparray.ctypes.data
+  x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+  x.offset = ctypes.c_longlong(0)
+  x.shape = nparray.ctypes.shape
+
+  # Numpy uses byte quantities to express strides, MLIR OTOH uses the
+  # torch abstraction which specifies strides in terms of elements.
+  strides_ctype_t = ctypes.c_longlong * nparray.ndim
+  x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
+  return x
+
 
 def get_unranked_memref_descriptor(nparray):
-    """
-    Return a generic/unranked memref descriptor for the given numpy array.
-    """
-    d = UnrankedMemRefDescriptor()
-    d.rank = nparray.ndim
-    x = get_ranked_memref_descriptor(nparray)
-    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
-    return d
+  """Returns a generic/unranked memref descriptor for the given numpy array."""
+  d = UnrankedMemRefDescriptor()
+  d.rank = nparray.ndim
+  x = get_ranked_memref_descriptor(nparray)
+  d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+  return d
 
 
 def unranked_memref_to_numpy(unranked_memref, np_dtype):
-    """
-    Converts unranked memrefs to numpy arrays.
-    """
-    descriptor = make_nd_memref_descriptor(
-        unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype)
-    )
-    val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
-    np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
-    strided_arr = np.lib.stride_tricks.as_strided(
-        np_arr,
-        np.ctypeslib.as_array(val[0].shape),
-        np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
-    )
-    return strided_arr
+  """Converts unranked memrefs to numpy arrays."""
+  ctp = as_ctype(np_dtype)
+  descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
+  val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+  np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+  strided_arr = np.lib.stride_tricks.as_strided(
+      np_arr,
+      np.ctypeslib.as_array(val[0].shape),
+      np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
+  )
+  if strided_arr.dtype == C128:
+    return strided_arr.view("complex128")
+  if strided_arr.dtype == C64:
+    return strided_arr.view("complex64")
+  return strided_arr
 
 
 def ranked_memref_to_numpy(ranked_memref):
-    """
-    Converts ranked memrefs to numpy arrays.
-    """
-    np_arr = np.ctypeslib.as_array(
-        ranked_memref[0].aligned, shape=ranked_memref[0].shape
-    )
-    strided_arr = np.lib.stride_tricks.as_strided(
-        np_arr,
-        np.ctypeslib.as_array(ranked_memref[0].shape),
-        np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
-    )
-    return strided_arr
+  """Converts ranked memrefs to numpy arrays."""
+  np_arr = np.ctypeslib.as_array(
+      ranked_memref[0].aligned, shape=ranked_memref[0].shape)
+  strided_arr = np.lib.stride_tricks.as_strided(
+      np_arr,
+      np.ctypeslib.as_array(ranked_memref[0].shape),
+      np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
+  )
+  if strided_arr.dtype == C128:
+    return strided_arr.view("complex128")
+  if strided_arr.dtype == C64:
+    return strided_arr.view("complex64")
+  return strided_arr

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 5eb0dffdc8456..53cbac35482e7 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -64,7 +64,7 @@ def testInvalidModule():
 def lowerToLLVM(module):
   import mlir.conversions
   pm = PassManager.parse(
-      "convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts")
+      "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts")
   pm.run(module)
   return module
 
@@ -266,6 +266,102 @@ def testMemrefAdd():
 run(testMemrefAdd)
 
 
+# Test addition of two complex memrefs
+# CHECK-LABEL: TEST: testComplexMemrefAdd
+def testComplexMemrefAdd():
+  with Context():
+    module = Module.parse("""
+    module  {
+      func.func @main(%arg0: memref<1xcomplex<f64>>,
+                      %arg1: memref<1xcomplex<f64>>,
+                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
+        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
+        %3 = complex.add %1, %2 : complex<f64>
+        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
+        return
+      }
+    } """)
+
+    arg1 = np.array([1.+2.j]).astype(np.complex128)
+    arg2 = np.array([3.+4.j]).astype(np.complex128)
+    arg3  = np.array([0.+0.j]).astype(np.complex128)
+
+    arg1_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+    arg2_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+    arg3_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_ranked_memref_descriptor(arg3)))
+
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.invoke("main",
+                            arg1_memref_ptr,
+                            arg2_memref_ptr,
+                            arg3_memref_ptr)
+    # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
+    log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+    # test to-numpy utility
+    # CHECK: [4.+6.j]
+    npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+    log(npout)
+
+
+run(testComplexMemrefAdd)
+
+
+# Test addition of two complex unranked memrefs
+# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
+def testComplexUnrankedMemrefAdd():
+  with Context():
+    module = Module.parse("""
+    module  {
+      func.func @main(%arg0: memref<*xcomplex<f32>>,
+                      %arg1: memref<*xcomplex<f32>>,
+                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
+        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
+        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
+        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
+        %0 = arith.constant 0 : index
+        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
+        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
+        %3 = complex.add %1, %2 : complex<f32>
+        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
+        return
+      }
+    } """)
+
+    arg1 = np.array([5.+6.j]).astype(np.complex64)
+    arg2 = np.array([7.+8.j]).astype(np.complex64)
+    arg3  = np.array([0.+0.j]).astype(np.complex64)
+
+    arg1_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_unranked_memref_descriptor(arg1)))
+    arg2_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_unranked_memref_descriptor(arg2)))
+    arg3_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_unranked_memref_descriptor(arg3)))
+
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.invoke("main",
+                            arg1_memref_ptr,
+                            arg2_memref_ptr,
+                            arg3_memref_ptr)
+    # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
+    log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+    # test to-numpy utility
+    # CHECK: [12.+14.j]
+    npout = unranked_memref_to_numpy(arg3_memref_ptr[0],
+                                     np.dtype(np.complex64))
+    log(npout)
+
+
+run(testComplexUnrankedMemrefAdd)
+
+
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():


        


More information about the Mlir-commits mailing list