[Mlir-commits] [mlir] e5e7e51 - [mlir][sparse][taco] Support complex types.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 17 16:06:57 PDT 2022


Author: bixia1
Date: 2022-06-17T16:06:53-07:00
New Revision: e5e7e5147322d4cbfd0c8309893c2273c3ee41ac

URL: https://github.com/llvm/llvm-project/commit/e5e7e5147322d4cbfd0c8309893c2273c3ee41ac
DIFF: https://github.com/llvm/llvm-project/commit/e5e7e5147322d4cbfd0c8309893c2273c3ee41ac.diff

LOG: [mlir][sparse][taco] Support complex types.

Support complex types with float and double components (complex64 and complex128). See the added test for an example.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D128076

Added: 
    mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py

Modified: 
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
new file mode 100644
index 0000000000000..723d39f9700fa
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
@@ -0,0 +1,31 @@
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+compressed = pt.compressed
+
+passed = 0
+all_types = [pt.complex64, pt.complex128]
+for t in all_types:
+  i, j = pt.get_index_vars(2)
+  A = pt.tensor([2, 3], dtype=t)
+  B = pt.tensor([2, 3], dtype=t)
+  C = pt.tensor([2, 3], compressed, dtype=t)
+  A.insert([0, 1], 10 + 20j)
+  A.insert([1, 2], 40 + 0.5j)
+  B.insert([0, 0], 20)
+  B.insert([1, 2], 30 + 15j)
+  C[i, j] = A[i, j] + B[i, j]
+
+  indices, values = C.get_coordinates_and_values()
+  passed += isinstance(values[0], t.value)
+  passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+  passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
+
+# CHECK: Number of passed: 6
+print("Number of passed:", passed)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 9bab366fcbe82..48b6b552bb110 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -75,16 +75,20 @@ class Type(enum.Enum):
   # numpy _ctype_from_dtype_scalar can't handle np.float16 yet.
   FLOAT32 = np.float32
   FLOAT64 = np.float64
+  COMPLEX64 = np.complex64
+  COMPLEX128 = np.complex128
 
 
 # All floating point type enums.
 _FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
 # All integral type enums.
 _INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
+# All complex type enums.
+_COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
 # Type alias for any numpy type used to implement the runtime support for the
 # enum data types.
 _AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32,
-                        np.float64]
+                        np.float64, np.complex64, np.complex128]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -111,6 +115,10 @@ def is_int(self) -> bool:
     """Returns whether the data type represents an integral value."""
     return self.kind in _INT_TYPES
 
+  def is_complex(self) -> bool:
+    """Returns whether the data type represents a complex value."""
+    return self.kind in _COMPLEX_TYPES
+
   @property
   def value(self) -> _AnyRuntimeType:
     """Returns the numpy dtype for the data type."""
@@ -125,7 +133,9 @@ def _dtype_to_mlir_str(dtype: DType) -> str:
       Type.INT32: "i32",
       Type.INT64: "i64",
       Type.FLOAT32: "f32",
-      Type.FLOAT64: "f64"
+      Type.FLOAT64: "f64",
+      Type.COMPLEX64: "complex<f32>",
+      Type.COMPLEX128: "complex<f64>"
   }
   return dtype_to_str[dtype.kind]
 
@@ -138,7 +148,9 @@ def _nptype_to_taco_type(ty: np.dtype) -> DType:
       np.int32: Type.INT32,
       np.int64: Type.INT64,
       np.float32: Type.FLOAT32,
-      np.float64: Type.FLOAT64
+      np.float64: Type.FLOAT64,
+      np.complex64: Type.COMPLEX64,
+      np.complex128: Type.COMPLEX128
   }
   return DType(nptype_to_dtype[ty])
 
@@ -151,7 +163,9 @@ def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
       Type.INT32: ir.IntegerType.get_signless(32),
       Type.INT64: ir.IntegerType.get_signless(64),
       Type.FLOAT32: ir.F32Type.get(),
-      Type.FLOAT64: ir.F64Type.get()
+      Type.FLOAT64: ir.F64Type.get(),
+      Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
+      Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get())
   }
   return dtype_to_irtype[dtype.kind]
 
@@ -1004,8 +1018,8 @@ def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat],
       raise ValueError(f"Invalid format argument: {fmt}.")
 
   def __init__(self,
-               value_or_shape: Optional[Union[List[int], Tuple[int, ...], float,
-                                              int]] = None,
+               value_or_shape: Optional[Union[List[int], Tuple[int, ...],
+                                              complex, float, int]] = None,
                fmt: Optional[Union[ModeFormat, List[ModeFormat],
                                    Format]] = None,
                dtype: Optional[DType] = None,
@@ -1059,7 +1073,7 @@ def __init__(self,
     self._values = []
     self._stats = _Stats()
     if value_or_shape is None or isinstance(value_or_shape, int) or isinstance(
-        value_or_shape, float):
+        value_or_shape, float) or isinstance(value_or_shape, complex):
       # Create a scalar tensor and ignore the fmt parameter.
       self._shape = []
       self._format = _make_format([], [])
@@ -1108,7 +1122,7 @@ def __repr__(self) -> str:
     return (f"Tensor(_name={repr(self._name)} "
             f"_dtype={repr(self._dtype)} : ") + value_str
 
-  def insert(self, coords: List[int], val: Union[float, int]) -> None:
+  def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
     """Inserts a value to the given coordinate.
 
     Args:
@@ -1134,7 +1148,8 @@ def insert(self, coords: List[int], val: Union[float, int]) -> None:
       raise ValueError("Invalid coordinate for rank: "
                        f"{self.order}, {coords}.")
 
-    if not isinstance(val, int) and not isinstance(val, float):
+    if not isinstance(val, int) and not isinstance(
+        val, float) and not isinstance(val, complex):
       raise ValueError(f"Value is neither int nor float: {val}.")
 
     self._coords.append(tuple(coords))

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
index 7573d6655360f..8300dfef5bc63 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
@@ -41,6 +41,8 @@
 int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64)
 float32 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32)
 float64 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64)
+complex64 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX64)
+complex128 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX128)
 
 # Storage format constants defined by the PyTACO API. In PyTACO, each storage
 # format constant has two aliasing names.

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
index ce6d3c70bd50c..f5ec14aa80b03 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
@@ -92,7 +92,11 @@ def _get_support_func_locator() -> _SupportFuncLocator:
                      (np.float32, c_lib.convertToMLIRSparseTensorF32,
                       c_lib.convertFromMLIRSparseTensorF32),
                      (np.float64, c_lib.convertToMLIRSparseTensorF64,
-                      c_lib.convertFromMLIRSparseTensorF64)]
+                      c_lib.convertFromMLIRSparseTensorF64),
+                     (np.complex64, c_lib.convertToMLIRSparseTensorC32,
+                      c_lib.convertFromMLIRSparseTensorC32),
+                     (np.complex128, c_lib.convertToMLIRSparseTensorC64,
+                      c_lib.convertFromMLIRSparseTensorC64)]
   except Exception as e:
     raise ValueError(f"Missing supporting function: {e}") from e
   for i, info in enumerate(support_types):
@@ -134,14 +138,15 @@ def sparse_tensor_to_coo_tensor(
   rank = ctypes.c_ulonglong(0)
   nse = ctypes.c_ulonglong(0)
   shape = ctypes.POINTER(ctypes.c_ulonglong)()
-  values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
+
+  values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
   indices = ctypes.POINTER(ctypes.c_ulonglong)()
   convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
                ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
 
   # Convert the returned values to the corresponding numpy types.
   shape = np.ctypeslib.as_array(shape, shape=[rank.value])
-  values = np.ctypeslib.as_array(values, shape=[nse.value])
+  values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
   indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
   return rank.value, nse.value, shape, values, indices
 
@@ -175,7 +180,7 @@ def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
   nse = ctypes.c_ulonglong(len(np_values))
   shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
   values = np_values.ctypes.data_as(
-      ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype)))
+      ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype))))
   indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
 
   perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))


        


More information about the Mlir-commits mailing list