[Mlir-commits] [mlir] 6438783 - [mlir][sparse] provide more types for external to/from MLIR routines
Aart Bik
llvmlistbot at llvm.org
Fri Feb 18 13:37:01 PST 2022
Author: Aart Bik
Date: 2022-02-18T13:36:52-08:00
New Revision: 6438783fdaf1a89bcc0945c3c03455793d802352
URL: https://github.com/llvm/llvm-project/commit/6438783fdaf1a89bcc0945c3c03455793d802352
DIFF: https://github.com/llvm/llvm-project/commit/6438783fdaf1a89bcc0945c3c03455793d802352.diff
LOG: [mlir][sparse] provide more types for external to/from MLIR routines
These routines will need to be specialized a lot more based on value types,
index types, pointer types, and permutation/dimension ordering. This is a
careful first step, providing some functionality needed in the PyTACO bridge.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D120154
Added:
Modified:
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
Removed:
################################################################################
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 665dd8663a6c2..a93836cefdc26 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -717,10 +717,15 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
/// Writes the sparse tensor to extended FROSTT format.
template <typename V>
-void outSparseTensor(const SparseTensorCOO<V> &tensor, char *filename) {
- auto &sizes = tensor.getSizes();
- auto &elements = tensor.getElements();
- uint64_t rank = tensor.getRank();
+void outSparseTensor(void *tensor, void *dest, bool sort) {
+ assert(tensor && dest);
+ auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
+ if (sort)
+ coo->sort();
+ char *filename = static_cast<char *>(dest);
+ auto &sizes = coo->getSizes();
+ auto &elements = coo->getElements();
+ uint64_t rank = coo->getRank();
uint64_t nnz = elements.size();
std::fstream file;
file.open(filename, std::ios_base::out | std::ios_base::trunc);
@@ -738,6 +743,67 @@ void outSparseTensor(const SparseTensorCOO<V> &tensor, char *filename) {
file.flush();
file.close();
assert(file.good());
+ delete coo;
+}
+
+/// Initializes a sparse tensor from an external COO-flavored format.
+template <typename V>
+SparseTensorStorage<uint64_t, uint64_t, V> *
+toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
+ uint64_t *indices) {
+ // Setup all-dims compressed and default ordering.
+ std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
+ std::vector<uint64_t> perm(rank);
+ std::iota(perm.begin(), perm.end(), 0);
+ // Convert external format to internal COO.
+ auto *tensor =
+ SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm.data(), nse);
+ std::vector<uint64_t> idx(rank);
+ for (uint64_t i = 0, base = 0; i < nse; i++) {
+ for (uint64_t r = 0; r < rank; r++)
+ idx[r] = indices[base + r];
+ tensor->add(idx, values[i]);
+ base += rank;
+ }
+ // Return sparse tensor storage format as opaque pointer.
+ return SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
+ rank, shape, perm.data(), sparse.data(), tensor);
+}
+
+/// Converts a sparse tensor to an external COO-flavored format.
+template <typename V>
+void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
+ uint64_t **pShape, V **pValues, uint64_t **pIndices) {
+ auto sparseTensor =
+ static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
+ uint64_t rank = sparseTensor->getRank();
+ std::vector<uint64_t> perm(rank);
+ std::iota(perm.begin(), perm.end(), 0);
+ SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
+
+ const std::vector<Element<V>> &elements = coo->getElements();
+ uint64_t nse = elements.size();
+
+ uint64_t *shape = new uint64_t[rank];
+ for (uint64_t i = 0; i < rank; i++)
+ shape[i] = coo->getSizes()[i];
+
+ V *values = new V[nse];
+ uint64_t *indices = new uint64_t[rank * nse];
+
+ for (uint64_t i = 0, base = 0; i < nse; i++) {
+ values[i] = elements[i].value;
+ for (uint64_t j = 0; j < rank; j++)
+ indices[base + j] = elements[i].indices[j];
+ base += rank;
+ }
+
+ delete coo;
+ *pRank = rank;
+ *pNse = nse;
+ *pShape = shape;
+ *pValues = values;
+ *pIndices = indices;
}
} // namespace
@@ -873,17 +939,6 @@ extern "C" {
cursor, values, filled, added, count); \
}
-#define IMPL_OUT(NAME, V) \
- void NAME(void *tensor, void *dest, bool sort) { \
- assert(tensor &&dest); \
- auto coo = static_cast<SparseTensorCOO<V> *>(tensor); \
- if (sort) \
- coo->sort(); \
- char *filename = static_cast<char *>(dest); \
- outSparseTensor<V>(*coo, filename); \
- delete coo; \
- }
-
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
// that this file cannot get out of sync with its header.
@@ -1048,8 +1103,7 @@ IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)
-/// Helper to insert elements in lexicographical index order, one per value
-/// type.
+/// Insert elements in lexicographical index order, one per value type.
IMPL_LEXINSERT(lexInsertF64, double)
IMPL_LEXINSERT(lexInsertF32, float)
IMPL_LEXINSERT(lexInsertI64, int64_t)
@@ -1057,7 +1111,7 @@ IMPL_LEXINSERT(lexInsertI32, int32_t)
IMPL_LEXINSERT(lexInsertI16, int16_t)
IMPL_LEXINSERT(lexInsertI8, int8_t)
-/// Helper to insert using expansion, one per value type.
+/// Insert using expansion, one per value type.
IMPL_EXPINSERT(expInsertF64, double)
IMPL_EXPINSERT(expInsertF32, float)
IMPL_EXPINSERT(expInsertI64, int64_t)
@@ -1065,14 +1119,6 @@ IMPL_EXPINSERT(expInsertI32, int32_t)
IMPL_EXPINSERT(expInsertI16, int16_t)
IMPL_EXPINSERT(expInsertI8, int8_t)
-/// Helper to output a sparse tensor, one per value type.
-IMPL_OUT(outSparseTensorF64, double)
-IMPL_OUT(outSparseTensorF32, float)
-IMPL_OUT(outSparseTensorI64, int64_t)
-IMPL_OUT(outSparseTensorI32, int32_t)
-IMPL_OUT(outSparseTensorI16, int16_t)
-IMPL_OUT(outSparseTensorI8, int8_t)
-
#undef CASE
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
@@ -1080,7 +1126,26 @@ IMPL_OUT(outSparseTensorI8, int8_t)
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT
-#undef IMPL_OUT
+
+/// Output a sparse tensor, one per value type.
+void outSparseTensorF64(void *tensor, void *dest, bool sort) {
+ return outSparseTensor<double>(tensor, dest, sort);
+}
+void outSparseTensorF32(void *tensor, void *dest, bool sort) {
+ return outSparseTensor<float>(tensor, dest, sort);
+}
+void outSparseTensorI64(void *tensor, void *dest, bool sort) {
+ return outSparseTensor<int64_t>(tensor, dest, sort);
+}
+void outSparseTensorI32(void *tensor, void *dest, bool sort) {
+ return outSparseTensor<int32_t>(tensor, dest, sort);
+}
+void outSparseTensorI16(void *tensor, void *dest, bool sort) {
+ return outSparseTensor<int16_t>(tensor, dest, sort);
+}
+void outSparseTensorI8(void *tensor, void *dest, bool sort) {
+ return outSparseTensor<int8_t>(tensor, dest, sort);
+}
//===----------------------------------------------------------------------===//
//
@@ -1134,27 +1199,16 @@ void delSparseTensor(void *tensor) {
/// values = [1.0, 5.0, 3.0]
/// indices = [ 0, 0, 1, 1, 1, 2]
//
-// TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
+// TODO: only 64-bit indices, the default dim ordering, and all-compressed
+// dimensions are supported for now; generalize beyond these restrictions.
//
-void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
- double *values, uint64_t *indices) {
- // Setup all-dims compressed and default ordering.
- std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
- std::vector<uint64_t> perm(rank);
- std::iota(perm.begin(), perm.end(), 0);
- // Convert external format to internal COO.
- SparseTensorCOO<double> *tensor = SparseTensorCOO<double>::newSparseTensorCOO(
- rank, shape, perm.data(), nse);
- std::vector<uint64_t> idx(rank);
- for (uint64_t i = 0, base = 0; i < nse; i++) {
- for (uint64_t r = 0; r < rank; r++)
- idx[r] = indices[base + r];
- tensor->add(idx, values[i]);
- base += rank;
- }
- // Return sparse tensor storage format as opaque pointer.
- return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
- rank, shape, perm.data(), sparse.data(), tensor);
+void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
+ double *values, uint64_t *indices) {
+ return toMLIRSparseTensor<double>(rank, nse, shape, values, indices);
+}
+void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
+ float *values, uint64_t *indices) {
+ return toMLIRSparseTensor<float>(rank, nse, shape, values, indices);
}
/// Converts a sparse tensor to COO-flavored format expressed using C-style
@@ -1174,41 +1228,18 @@ void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
// SparseTensorCOO, then to the output. We may want to reduce the number of
// copies.
//
-// TODO: for now f64 tensors only, no dim ordering, all dimensions compressed
+// TODO: only 64-bit indices, the default dim ordering, and all-compressed
+// dimensions are supported for now; generalize beyond these restrictions.
//
-void convertFromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
- uint64_t **pShape, double **pValues,
- uint64_t **pIndices) {
- SparseTensorStorage<uint64_t, uint64_t, double> *sparseTensor =
- static_cast<SparseTensorStorage<uint64_t, uint64_t, double> *>(tensor);
- uint64_t rank = sparseTensor->getRank();
- std::vector<uint64_t> perm(rank);
- std::iota(perm.begin(), perm.end(), 0);
- SparseTensorCOO<double> *coo = sparseTensor->toCOO(perm.data());
-
- const std::vector<Element<double>> &elements = coo->getElements();
- uint64_t nse = elements.size();
-
- uint64_t *shape = new uint64_t[rank];
- for (uint64_t i = 0; i < rank; i++)
- shape[i] = coo->getSizes()[i];
-
- double *values = new double[nse];
- uint64_t *indices = new uint64_t[rank * nse];
-
- for (uint64_t i = 0, base = 0; i < nse; i++) {
- values[i] = elements[i].value;
- for (uint64_t j = 0; j < rank; j++)
- indices[base + j] = elements[i].indices[j];
- base += rank;
- }
-
- delete coo;
- *pRank = rank;
- *pNse = nse;
- *pShape = shape;
- *pValues = values;
- *pIndices = indices;
+void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
+ uint64_t *pNse, uint64_t **pShape,
+ double **pValues, uint64_t **pIndices) {
+ fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
+}
+void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
+ uint64_t *pNse, uint64_t **pShape,
+ float **pValues, uint64_t **pIndices) {
+ fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
} // extern "C"
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
index d238e6fdb79b4..f5b0ab60e85e9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
@@ -28,9 +28,9 @@ def _get_c_shared_lib(lib_name: str):
c_lib = ctypes.CDLL(lib_name)
try:
- c_lib.convertFromMLIRSparseTensor.restype = ctypes.c_void_p
+ c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
except Exception as e:
- raise ValueError('Missing function convertFromMLIRSparseTensor from '
+ raise ValueError('Missing function convertFromMLIRSparseTensorF64 from '
f'the C shared library: {e} ') from e
return c_lib
@@ -64,9 +64,10 @@ def sparse_tensor_to_coo_tensor(support_lib, sparse, dtype):
shape = ctypes.POINTER(ctypes.c_ulonglong)()
values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
- c_lib.convertFromMLIRSparseTensor(sparse, ctypes.byref(rank),
- ctypes.byref(nse), ctypes.byref(shape),
- ctypes.byref(values), ctypes.byref(indices))
+ c_lib.convertFromMLIRSparseTensorF64(sparse, ctypes.byref(rank),
+ ctypes.byref(nse), ctypes.byref(shape),
+ ctypes.byref(values),
+ ctypes.byref(indices))
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = np.ctypeslib.as_array(values, shape=[nse.value])
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
index 62cd6baff6388..62aa98ee8aaf8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
@@ -55,15 +55,15 @@ def _get_c_shared_lib() -> ctypes.CDLL:
c_lib = ctypes.CDLL(_get_support_lib_name())
try:
- c_lib.convertToMLIRSparseTensor.restype = ctypes.c_void_p
+ c_lib.convertToMLIRSparseTensorF64.restype = ctypes.c_void_p
except Exception as e:
- raise ValueError("Missing function convertToMLIRSparseTensor from "
+ raise ValueError("Missing function convertToMLIRSparseTensorF64 from "
f"the supporting C shared library: {e} ") from e
try:
- c_lib.convertFromMLIRSparseTensor.restype = ctypes.c_void_p
+ c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
except Exception as e:
- raise ValueError("Missing function convertFromMLIRSparseTensor from "
+ raise ValueError("Missing function convertFromMLIRSparseTensorF64 from "
f"the C shared library: {e} ") from e
return c_lib
@@ -100,9 +100,10 @@ def sparse_tensor_to_coo_tensor(
shape = ctypes.POINTER(ctypes.c_ulonglong)()
values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
- c_lib.convertFromMLIRSparseTensor(sparse_tensor, ctypes.byref(rank),
- ctypes.byref(nse), ctypes.byref(shape),
- ctypes.byref(values), ctypes.byref(indices))
+ c_lib.convertFromMLIRSparseTensorF64(sparse_tensor, ctypes.byref(rank),
+ ctypes.byref(nse), ctypes.byref(shape),
+ ctypes.byref(values),
+ ctypes.byref(indices))
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
@@ -138,8 +139,8 @@ def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
c_lib = _get_c_shared_lib()
- ptr = c_lib.convertToMLIRSparseTensor(rank, nse, shape, values, indices)
- assert ptr is not None, "Problem with calling convertToMLIRSparseTensor"
+ ptr = c_lib.convertToMLIRSparseTensorF64(rank, nse, shape, values, indices)
+ assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
return ptr
More information about the Mlir-commits mailing list