[Mlir-commits] [mlir] 1313f5d - [mlir][sparse] Restyling macros in the runtime library
wren romano
llvmlistbot at llvm.org
Mon May 16 16:43:45 PDT 2022
Author: wren romano
Date: 2022-05-16T16:43:39-07:00
New Revision: 1313f5d3071c5aee6eaf3c366747d44585522fb4
URL: https://github.com/llvm/llvm-project/commit/1313f5d3071c5aee6eaf3c366747d44585522fb4
DIFF: https://github.com/llvm/llvm-project/commit/1313f5d3071c5aee6eaf3c366747d44585522fb4.diff
LOG: [mlir][sparse] Restyling macros in the runtime library
In addition to reducing code repetition, this also helps ensure that the various API functions follow the naming convention of mlir::sparse_tensor::primaryTypeFunctionSuffix (which the repetitious code had occasionally violated, e.g., through typos).
Depends On D125428
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D125431
Added:
Modified:
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index c5098d531ffef..b8731ff1c745b 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -234,6 +234,29 @@ struct SparseTensorCOO final {
unsigned iteratorPos = 0;
};
+// See <https://en.wikipedia.org/wiki/X_Macro>
+//
+// `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
+// the ABI for complex types has compiler/architecture-dependent complexities
+// that we need to work around. Namely, when a function takes a parameter of
+// C/C++ type `complex32` (per se), then there is additional padding that
+// causes it not to match the LLVM type `!llvm.struct<(f32, f32)>`. This
+// only happens with the `complex32` type itself, not with pointers/arrays
+// of complex values. So far `complex64` doesn't exhibit this ABI
+// incompatibility, but we exclude it anyway just to be safe.
+#define FOREVERY_SIMPLEX_V(DO) \
+ DO(F64, double) \
+ DO(F32, float) \
+ DO(I64, int64_t) \
+ DO(I32, int32_t) \
+ DO(I16, int16_t) \
+ DO(I8, int8_t)
+
+#define FOREVERY_V(DO) \
+ FOREVERY_SIMPLEX_V(DO) \
+ DO(C64, complex64) \
+ DO(C32, complex32)
+
// Forward.
template <typename V>
class SparseTensorEnumeratorBase;
@@ -298,38 +321,13 @@ class SparseTensorStorageBase {
}
/// Allocate a new enumerator.
- virtual void newEnumerator(SparseTensorEnumeratorBase<double> **, uint64_t,
- const uint64_t *) const {
- fatal("enumf64");
- }
- virtual void newEnumerator(SparseTensorEnumeratorBase<float> **, uint64_t,
- const uint64_t *) const {
- fatal("enumf32");
- }
- virtual void newEnumerator(SparseTensorEnumeratorBase<int64_t> **, uint64_t,
- const uint64_t *) const {
- fatal("enumi64");
- }
- virtual void newEnumerator(SparseTensorEnumeratorBase<int32_t> **, uint64_t,
- const uint64_t *) const {
- fatal("enumi32");
- }
- virtual void newEnumerator(SparseTensorEnumeratorBase<int16_t> **, uint64_t,
- const uint64_t *) const {
- fatal("enumi16");
- }
- virtual void newEnumerator(SparseTensorEnumeratorBase<int8_t> **, uint64_t,
- const uint64_t *) const {
- fatal("enumi8");
- }
- virtual void newEnumerator(SparseTensorEnumeratorBase<complex64> **, uint64_t,
- const uint64_t *) const {
- fatal("enumc64");
- }
- virtual void newEnumerator(SparseTensorEnumeratorBase<complex32> **, uint64_t,
- const uint64_t *) const {
- fatal("enumc32");
+#define DECL_NEWENUMERATOR(VNAME, V) \
+ virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
+ const uint64_t *) const { \
+ fatal("newEnumerator" #VNAME); \
}
+ FOREVERY_V(DECL_NEWENUMERATOR)
+#undef DECL_NEWENUMERATOR
/// Overhead storage.
virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
@@ -342,52 +340,24 @@ class SparseTensorStorageBase {
virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
/// Primary storage.
- virtual void getValues(std::vector<double> **) { fatal("valf64"); }
- virtual void getValues(std::vector<float> **) { fatal("valf32"); }
- virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
- virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
- virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
- virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
- virtual void getValues(std::vector<complex64> **) { fatal("valc64"); }
- virtual void getValues(std::vector<complex32> **) { fatal("valc32"); }
+#define DECL_GETVALUES(VNAME, V) \
+ virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
+ FOREVERY_V(DECL_GETVALUES)
+#undef DECL_GETVALUES
/// Element-wise insertion in lexicographic index order.
- virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
- virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
- virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
- virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
- virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); }
- virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
- virtual void lexInsert(const uint64_t *, complex64) { fatal("insc64"); }
- virtual void lexInsert(const uint64_t *, complex32) { fatal("insc32"); }
+#define DECL_LEXINSERT(VNAME, V) \
+ virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
+ FOREVERY_V(DECL_LEXINSERT)
+#undef DECL_LEXINSERT
/// Expanded insertion.
- virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
- fatal("expf64");
- }
- virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
- fatal("expf32");
- }
- virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
- fatal("expi64");
- }
- virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
- fatal("expi32");
- }
- virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
- fatal("expi16");
- }
- virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
- fatal("expi8");
- }
- virtual void expInsert(uint64_t *, complex64 *, bool *, uint64_t *,
- uint64_t) {
- fatal("expc64");
- }
- virtual void expInsert(uint64_t *, complex32 *, bool *, uint64_t *,
- uint64_t) {
- fatal("expc32");
+#define DECL_EXPINSERT(VNAME, V) \
+ virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
+ fatal("expInsert" #VNAME); \
}
+ FOREVERY_V(DECL_EXPINSERT)
+#undef DECL_EXPINSERT
/// Finishes insertion.
virtual void endInsert() = 0;
@@ -1440,17 +1410,23 @@ extern "C" {
}
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
+// TODO(D125432): move `_mlir_ciface_newSparseTensor` closer to these
+// macro definitions, but as a separate change so as not to muddy the diff.
-#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
- void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
+/// Methods that provide direct access to values.
+#define IMPL_SPARSEVALUES(VNAME, V) \
+ void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \
+ void *tensor) { \
assert(ref &&tensor); \
- std::vector<TYPE> *v; \
- static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v); \
+ std::vector<V> *v; \
+ static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
ref->basePtr = ref->data = v->data(); \
ref->offset = 0; \
ref->sizes[0] = v->size(); \
ref->strides[0] = 1; \
}
+FOREVERY_V(IMPL_SPARSEVALUES)
+#undef IMPL_SPARSEVALUES
#define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
@@ -1463,12 +1439,27 @@ extern "C" {
ref->sizes[0] = v->size(); \
ref->strides[0] = 1; \
}
+/// Methods that provide direct access to pointers.
+IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
+IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
+IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
+IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
+IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
+
+/// Methods that provide direct access to indices.
+IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
+IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
+IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
+IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
+IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
+#undef IMPL_GETOVERHEAD
-#define IMPL_ADDELT(NAME, TYPE) \
- void *_mlir_ciface_##NAME(void *tensor, TYPE value, \
- StridedMemRefType<index_type, 1> *iref, \
- StridedMemRefType<index_type, 1> *pref) { \
- assert(tensor &&iref &&pref); \
+/// Helper to add value to coordinate scheme, one per value type.
+#define IMPL_ADDELT(VNAME, V) \
+ void *_mlir_ciface_addElt##VNAME(void *coo, V value, \
+ StridedMemRefType<index_type, 1> *iref, \
+ StridedMemRefType<index_type, 1> *pref) { \
+ assert(coo &&iref &&pref); \
assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
assert(iref->sizes[0] == pref->sizes[0]); \
const index_type *indx = iref->data + iref->offset; \
@@ -1477,21 +1468,33 @@ extern "C" {
std::vector<index_type> indices(isize); \
for (uint64_t r = 0; r < isize; r++) \
indices[perm[r]] = indx[r]; \
- static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value); \
- return tensor; \
+ static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value); \
+ return coo; \
}
+FOREVERY_SIMPLEX_V(IMPL_ADDELT)
+// `complex64` apparently doesn't encounter any ABI issues (yet).
+IMPL_ADDELT(C64, complex64)
+// TODO: cleaner way to avoid ABI padding problem?
+IMPL_ADDELT(C32ABI, complex32)
+void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
+ StridedMemRefType<index_type, 1> *iref,
+ StridedMemRefType<index_type, 1> *pref) {
+ return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
+}
+#undef IMPL_ADDELT
-#define IMPL_GETNEXT(NAME, V) \
- bool _mlir_ciface_##NAME(void *tensor, \
- StridedMemRefType<index_type, 1> *iref, \
- StridedMemRefType<V, 0> *vref) { \
- assert(tensor &&iref &&vref); \
+/// Helper to enumerate elements of coordinate scheme, one per value type.
+#define IMPL_GETNEXT(VNAME, V) \
+ bool _mlir_ciface_getNext##VNAME(void *coo, \
+ StridedMemRefType<index_type, 1> *iref, \
+ StridedMemRefType<V, 0> *vref) { \
+ assert(coo &&iref &&vref); \
assert(iref->strides[0] == 1); \
index_type *indx = iref->data + iref->offset; \
V *value = vref->data + vref->offset; \
const uint64_t isize = iref->sizes[0]; \
- auto iter = static_cast<SparseTensorCOO<V> *>(tensor); \
- const Element<V> *elem = iter->getNext(); \
+ const Element<V> *elem = \
+ static_cast<SparseTensorCOO<V> *>(coo)->getNext(); \
if (elem == nullptr) \
return false; \
for (uint64_t r = 0; r < isize; r++) \
@@ -1499,19 +1502,34 @@ extern "C" {
*value = elem->value; \
return true; \
}
+FOREVERY_V(IMPL_GETNEXT)
+#undef IMPL_GETNEXT
-#define IMPL_LEXINSERT(NAME, V) \
- void _mlir_ciface_##NAME(void *tensor, \
- StridedMemRefType<index_type, 1> *cref, V val) { \
+/// Insert elements in lexicographical index order, one per value type.
+#define IMPL_LEXINSERT(VNAME, V) \
+ void _mlir_ciface_lexInsert##VNAME( \
+ void *tensor, StridedMemRefType<index_type, 1> *cref, V val) { \
assert(tensor &&cref); \
assert(cref->strides[0] == 1); \
index_type *cursor = cref->data + cref->offset; \
assert(cursor); \
static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val); \
}
+FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
+// `complex64` apparently doesn't encounter any ABI issues (yet).
+IMPL_LEXINSERT(C64, complex64)
+// TODO: cleaner way to avoid ABI padding problem?
+IMPL_LEXINSERT(C32ABI, complex32)
+void _mlir_ciface_lexInsertC32(void *tensor,
+ StridedMemRefType<index_type, 1> *cref, float r,
+ float i) {
+ _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
+}
+#undef IMPL_LEXINSERT
-#define IMPL_EXPINSERT(NAME, V) \
- void _mlir_ciface_##NAME( \
+/// Insert using expansion, one per value type.
+#define IMPL_EXPINSERT(VNAME, V) \
+ void _mlir_ciface_expInsert##VNAME( \
void *tensor, StridedMemRefType<index_type, 1> *cref, \
StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
StridedMemRefType<index_type, 1> *aref, index_type count) { \
@@ -1528,6 +1546,8 @@ extern "C" {
static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
cursor, values, filled, added, count); \
}
+FOREVERY_V(IMPL_EXPINSERT)
+#undef IMPL_EXPINSERT
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
@@ -1658,122 +1678,16 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
fputs("unsupported combination of types\n", stderr);
exit(1);
}
-
-/// Methods that provide direct access to pointers.
-IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
-IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
-IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
-
-/// Methods that provide direct access to indices.
-IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
-IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
-IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
-
-/// Methods that provide direct access to values.
-IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
-IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
-IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
-IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
-IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
-IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
-IMPL_SPARSEVALUES(sparseValuesC64, complex64, getValues)
-IMPL_SPARSEVALUES(sparseValuesC32, complex32, getValues)
-
-/// Helper to add value to coordinate scheme, one per value type.
-IMPL_ADDELT(addEltF64, double)
-IMPL_ADDELT(addEltF32, float)
-IMPL_ADDELT(addEltI64, int64_t)
-IMPL_ADDELT(addEltI32, int32_t)
-IMPL_ADDELT(addEltI16, int16_t)
-IMPL_ADDELT(addEltI8, int8_t)
-IMPL_ADDELT(addEltC64, complex64)
-IMPL_ADDELT(addEltC32ABI, complex32)
-// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
-// any padding (which seem to happen for complex32 when passed as scalar;
-// all other cases, e.g. pointer to array, work as expected).
-// TODO: cleaner way to avoid ABI padding problem?
-void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
- StridedMemRefType<index_type, 1> *iref,
- StridedMemRefType<index_type, 1> *pref) {
- return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
-}
-
-/// Helper to enumerate elements of coordinate scheme, one per value type.
-IMPL_GETNEXT(getNextF64, double)
-IMPL_GETNEXT(getNextF32, float)
-IMPL_GETNEXT(getNextI64, int64_t)
-IMPL_GETNEXT(getNextI32, int32_t)
-IMPL_GETNEXT(getNextI16, int16_t)
-IMPL_GETNEXT(getNextI8, int8_t)
-IMPL_GETNEXT(getNextC64, complex64)
-IMPL_GETNEXT(getNextC32, complex32)
-
-/// Insert elements in lexicographical index order, one per value type.
-IMPL_LEXINSERT(lexInsertF64, double)
-IMPL_LEXINSERT(lexInsertF32, float)
-IMPL_LEXINSERT(lexInsertI64, int64_t)
-IMPL_LEXINSERT(lexInsertI32, int32_t)
-IMPL_LEXINSERT(lexInsertI16, int16_t)
-IMPL_LEXINSERT(lexInsertI8, int8_t)
-IMPL_LEXINSERT(lexInsertC64, complex64)
-IMPL_LEXINSERT(lexInsertC32ABI, complex32)
-// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
-// any padding (which seem to happen for complex32 when passed as scalar;
-// all other cases, e.g. pointer to array, work as expected).
-// TODO: cleaner way to avoid ABI padding problem?
-void _mlir_ciface_lexInsertC32(void *tensor,
- StridedMemRefType<index_type, 1> *cref, float r,
- float i) {
- _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
-}
-
-/// Insert using expansion, one per value type.
-IMPL_EXPINSERT(expInsertF64, double)
-IMPL_EXPINSERT(expInsertF32, float)
-IMPL_EXPINSERT(expInsertI64, int64_t)
-IMPL_EXPINSERT(expInsertI32, int32_t)
-IMPL_EXPINSERT(expInsertI16, int16_t)
-IMPL_EXPINSERT(expInsertI8, int8_t)
-IMPL_EXPINSERT(expInsertC64, complex64)
-IMPL_EXPINSERT(expInsertC32, complex32)
-
#undef CASE
-#undef IMPL_SPARSEVALUES
-#undef IMPL_GETOVERHEAD
-#undef IMPL_ADDELT
-#undef IMPL_GETNEXT
-#undef IMPL_LEXINSERT
-#undef IMPL_EXPINSERT
+#undef CASE_SECSAME
/// Output a sparse tensor, one per value type.
-void outSparseTensorF64(void *tensor, void *dest, bool sort) {
- return outSparseTensor<double>(tensor, dest, sort);
-}
-void outSparseTensorF32(void *tensor, void *dest, bool sort) {
- return outSparseTensor<float>(tensor, dest, sort);
-}
-void outSparseTensorI64(void *tensor, void *dest, bool sort) {
- return outSparseTensor<int64_t>(tensor, dest, sort);
-}
-void outSparseTensorI32(void *tensor, void *dest, bool sort) {
- return outSparseTensor<int32_t>(tensor, dest, sort);
-}
-void outSparseTensorI16(void *tensor, void *dest, bool sort) {
- return outSparseTensor<int16_t>(tensor, dest, sort);
-}
-void outSparseTensorI8(void *tensor, void *dest, bool sort) {
- return outSparseTensor<int8_t>(tensor, dest, sort);
-}
-void outSparseTensorC64(void *tensor, void *dest, bool sort) {
- return outSparseTensor<complex64>(tensor, dest, sort);
-}
-void outSparseTensorC32(void *tensor, void *dest, bool sort) {
- return outSparseTensor<complex32>(tensor, dest, sort);
-}
+#define IMPL_OUTSPARSETENSOR(VNAME, V) \
+ void outSparseTensor##VNAME(void *coo, void *dest, bool sort) { \
+ return outSparseTensor<V>(coo, dest, sort); \
+ }
+FOREVERY_V(IMPL_OUTSPARSETENSOR)
+#undef IMPL_OUTSPARSETENSOR
//===----------------------------------------------------------------------===//
//
@@ -1817,14 +1731,7 @@ void delSparseTensor(void *tensor) {
void delSparseTensorCOO##VNAME(void *coo) { \
delete static_cast<SparseTensorCOO<V> *>(coo); \
}
-IMPL_DELCOO(F64, double)
-IMPL_DELCOO(F32, float)
-IMPL_DELCOO(I64, int64_t)
-IMPL_DELCOO(I32, int32_t)
-IMPL_DELCOO(I16, int16_t)
-IMPL_DELCOO(I8, int8_t)
-IMPL_DELCOO(C64, complex64)
-IMPL_DELCOO(C32, complex32)
+FOREVERY_V(IMPL_DELCOO)
#undef IMPL_DELCOO
/// Initializes sparse tensor from a COO-flavored format expressed using C-style
@@ -1850,54 +1757,15 @@ IMPL_DELCOO(C32, complex32)
//
// TODO: generalize beyond 64-bit indices.
//
-void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
- double *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
- sparse);
-}
-void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
- float *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
- sparse);
-}
-void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
- int64_t *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
- sparse);
-}
-void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
- int32_t *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
- sparse);
-}
-void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
- int16_t *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
- sparse);
-}
-void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
- int8_t *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
- sparse);
-}
-void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
- complex64 *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<complex64>(rank, nse, shape, values, indices, perm,
- sparse);
-}
-void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
- complex32 *values, uint64_t *indices,
- uint64_t *perm, uint8_t *sparse) {
- return toMLIRSparseTensor<complex32>(rank, nse, shape, values, indices, perm,
- sparse);
-}
+#define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \
+ void *convertToMLIRSparseTensor##VNAME( \
+ uint64_t rank, uint64_t nse, uint64_t *shape, V *values, \
+ uint64_t *indices, uint64_t *perm, uint8_t *sparse) { \
+ return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm, \
+ sparse); \
+ }
+FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
+#undef IMPL_CONVERTTOMLIRSPARSETENSOR
/// Converts a sparse tensor to COO-flavored format expressed using C-style
/// data structures. The expected output parameters are pointers for these
@@ -1919,48 +1787,14 @@ void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
// TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
// compressed
//
-void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- double **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- float **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- int64_t **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- int32_t **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- int16_t **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- int8_t **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
-}
-void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- complex64 **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<complex64>(tensor, pRank, pNse, pShape, pValues,
- pIndices);
-}
-void convertFromMLIRSparseTensorC32(void *tensor, uint64_t *pRank,
- uint64_t *pNse, uint64_t **pShape,
- complex32 **pValues, uint64_t **pIndices) {
- fromMLIRSparseTensor<complex32>(tensor, pRank, pNse, pShape, pValues,
- pIndices);
-}
+#define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V) \
+ void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank, \
+ uint64_t *pNse, uint64_t **pShape, \
+ V **pValues, uint64_t **pIndices) { \
+ fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices); \
+ }
+FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
+#undef IMPL_CONVERTFROMMLIRSPARSETENSOR
} // extern "C"
More information about the Mlir-commits
mailing list