[Mlir-commits] [mlir] a4c53f8 - [mlir][sparse] Factoring out SparseTensorFile class for readSparseTensorShape
wren romano
llvmlistbot at llvm.org
Tue May 31 13:24:34 PDT 2022
Author: wren romano
Date: 2022-05-31T13:24:28-07:00
New Revision: a4c53f8cd6e4519242690b9e5aa54a928609cf8b
URL: https://github.com/llvm/llvm-project/commit/a4c53f8cd6e4519242690b9e5aa54a928609cf8b
DIFF: https://github.com/llvm/llvm-project/commit/a4c53f8cd6e4519242690b9e5aa54a928609cf8b.diff
LOG: [mlir][sparse] Factoring out SparseTensorFile class for readSparseTensorShape
The primary goal of this change is to define readSparseTensorShape; the SparseTensorFile class is merely introduced along the way as a means of reducing code duplication.
Depends On D126106
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D126233
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
index e733544eea35f..51e78c3f27fc2 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -274,6 +274,11 @@ FOREVERY_V(DECL_DELCOO)
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
+/// Helper function to read the header of a file and return the
+/// shape/sizes, without parsing the elements of the file.
+MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename,
+ std::vector<uint64_t> *out);
+
/// Initializes sparse tensor from a COO-flavored format expressed using
/// C-style data structures. The expected parameters are:
///
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 430ba85affa31..da3b705e9e9db 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -88,9 +88,11 @@ static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
exit(1); \
}
-// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
-// That version doesn't have the permutation, and the `dimSizes` are
-// a pointer/C-array rather than `std::vector`.
+// TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
+// which is used by `openSparseTensorCOO`. It's easy enough to resolve
+// the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
+// to resolve the presence/absence of `perm` (without introducing extra
+// overhead), so perhaps the code duplication is unavoidable.
//
/// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
/// semantic-order to target-order) are a refinement of the desired `shape`
@@ -1099,9 +1101,128 @@ static char *toLower(char *token) {
return token;
}
+/// This class abstracts over the information stored in file headers,
+/// as well as providing the buffers and methods for parsing those headers.
+class SparseTensorFile final {
+public:
+ explicit SparseTensorFile(char *filename) : filename(filename) {
+ assert(filename && "Received nullptr for filename");
+ }
+
+ // Disallows copying, to avoid duplicating the `file` pointer.
+ SparseTensorFile(const SparseTensorFile &) = delete;
+ SparseTensorFile &operator=(const SparseTensorFile &) = delete;
+
+ // This dtor tries to avoid leaking the `file`. (Though it's better
+ // to call `closeFile` explicitly when possible, since there are
+ // circumstances where dtors are not called reliably.)
+ ~SparseTensorFile() { closeFile(); }
+
+ /// Opens the file for reading.
+ void openFile() {
+ if (file)
+ FATAL("Already opened file %s\n", filename);
+ file = fopen(filename, "r");
+ if (!file)
+ FATAL("Cannot find file %s\n", filename);
+ }
+
+ /// Closes the file.
+ void closeFile() {
+ if (file) {
+ fclose(file);
+ file = nullptr;
+ }
+ }
+
+ // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
+ // loop of `openSparseTensorCOO` into methods of this class, so we can
+ // avoid leaking access to the `line` pointer (both for general hygiene
+ // and because we can't mark it const due to the second argument of
+ // `strtoul`/`strtod` being `char * *restrict` rather than
+ // `char const* *restrict`).
+ //
+ /// Attempts to read a line from the file.
+ char *readLine() {
+ if (fgets(line, kColWidth, file))
+ return line;
+ FATAL("Cannot read next line of %s\n", filename);
+ }
+
+ /// Reads and parses the file's header.
+ void readHeader() {
+ assert(file && "Attempt to readHeader() before openFile()");
+ if (strstr(filename, ".mtx"))
+ readMMEHeader();
+ else if (strstr(filename, ".tns"))
+ readExtFROSTTHeader();
+ else
+ FATAL("Unknown format %s\n", filename);
+ assert(isValid && "Failed to read the header");
+ }
+
+ /// Gets the MME "pattern" property setting. Is only valid after
+ /// parsing the header.
+ bool isPattern() const {
+ assert(isValid && "Attempt to isPattern() before readHeader()");
+ return isPattern_;
+ }
+
+ /// Gets the MME "symmetric" property setting. Is only valid after
+ /// parsing the header.
+ bool isSymmetric() const {
+ assert(isValid && "Attempt to isSymmetric() before readHeader()");
+ return isSymmetric_;
+ }
+
+ /// Gets the rank of the tensor. Is only valid after parsing the header.
+ uint64_t getRank() const {
+ assert(isValid && "Attempt to getRank() before readHeader()");
+ return idata[0];
+ }
+
+ /// Gets the number of non-zeros. Is only valid after parsing the header.
+ uint64_t getNNZ() const {
+ assert(isValid && "Attempt to getNNZ() before readHeader()");
+ return idata[1];
+ }
+
+ /// Gets the dimension-sizes array. The pointer itself is always
+ /// valid; however, the values stored therein are only valid after
+ /// parsing the header.
+ const uint64_t *getDimSizes() const { return idata + 2; }
+
+ /// Safely gets the size of the given dimension. Is only valid
+ /// after parsing the header.
+ uint64_t getDimSize(uint64_t d) const {
+ assert(d < getRank());
+ return idata[2 + d];
+ }
+
+ /// Asserts the shape subsumes the actual dimension sizes. Is only
+ /// valid after parsing the header.
+ void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
+ assert(rank == getRank() && "Rank mismatch");
+ for (uint64_t r = 0; r < rank; r++)
+ assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
+ "Dimension size mismatch");
+ }
+
+private:
+ void readMMEHeader();
+ void readExtFROSTTHeader();
+
+ const char *filename;
+ FILE *file = nullptr;
+ bool isValid = false;
+ bool isPattern_ = false;
+ bool isSymmetric_ = false;
+ uint64_t idata[512];
+ char line[kColWidth];
+};
+
/// Read the MME header of a general sparse matrix of type real.
-static void readMMEHeader(FILE *file, char *filename, char *line,
- uint64_t *idata, bool *isPattern, bool *isSymmetric) {
+void SparseTensorFile::readMMEHeader() {
char header[64];
char object[64];
char format[64];
@@ -1112,19 +1233,18 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
symmetry) != 5)
FATAL("Corrupt header in %s\n", filename);
// Set properties
- *isPattern = (strcmp(toLower(field), "pattern") == 0);
- *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
+ isPattern_ = (strcmp(toLower(field), "pattern") == 0);
+ isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
// Make sure this is a general sparse matrix.
if (strcmp(toLower(header), "%%matrixmarket") ||
strcmp(toLower(object), "matrix") ||
strcmp(toLower(format), "coordinate") ||
- (strcmp(toLower(field), "real") && !(*isPattern)) ||
- (strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
+ (strcmp(toLower(field), "real") && !isPattern_) ||
+ (strcmp(toLower(symmetry), "general") && !isSymmetric_))
FATAL("Cannot find a general sparse matrix in %s\n", filename);
// Skip comments.
while (true) {
- if (!fgets(line, kColWidth, file))
- FATAL("Cannot find data in %s\n", filename);
+ readLine();
if (line[0] != '%')
break;
}
@@ -1133,18 +1253,17 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
idata + 1) != 3)
FATAL("Cannot find size in %s\n", filename);
+ isValid = true;
}
/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimensions sizes (one per rank) of the sparse tensor.
-static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
- uint64_t *idata) {
+void SparseTensorFile::readExtFROSTTHeader() {
// Skip comments.
while (true) {
- if (!fgets(line, kColWidth, file))
- FATAL("Cannot find data in %s\n", filename);
+ readLine();
if (line[0] != '#')
break;
}
@@ -1155,7 +1274,8 @@ static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
for (uint64_t r = 0; r < idata[0]; r++)
if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
FATAL("Cannot find dimension size %s\n", filename);
- fgets(line, kColWidth, file); // end of line
+ readLine(); // end of line
+ isValid = true;
}
/// Reads a sparse tensor with the given filename into a memory-resident
@@ -1164,38 +1284,19 @@ template <typename V>
static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
const uint64_t *shape,
const uint64_t *perm) {
- // Open the file.
- assert(filename && "Received nullptr for filename");
- FILE *file = fopen(filename, "r");
- if (!file)
- FATAL("Cannot find file %s\n", filename);
- // Perform some file format dependent set up.
- char line[kColWidth];
- uint64_t idata[512];
- bool isPattern = false;
- bool isSymmetric = false;
- if (strstr(filename, ".mtx")) {
- readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
- } else if (strstr(filename, ".tns")) {
- readExtFROSTTHeader(file, filename, line, idata);
- } else {
- FATAL("Unknown format %s\n", filename);
- }
+ SparseTensorFile stfile(filename);
+ stfile.openFile();
+ stfile.readHeader();
+ stfile.assertMatchesShape(rank, shape);
// Prepare sparse tensor object with per-dimension sizes
// and the number of nonzeros as initial capacity.
- assert(rank == idata[0] && "rank mismatch");
- uint64_t nnz = idata[1];
- for (uint64_t r = 0; r < rank; r++)
- assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
- "dimension size mismatch");
- SparseTensorCOO<V> *tensor =
- SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
+ uint64_t nnz = stfile.getNNZ();
+ auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
+ perm, nnz);
// Read all nonzero elements.
std::vector<uint64_t> indices(rank);
for (uint64_t k = 0; k < nnz; k++) {
- if (!fgets(line, kColWidth, file))
- FATAL("Cannot find next line of data in %s\n", filename);
- char *linePtr = line;
+ char *linePtr = stfile.readLine();
for (uint64_t r = 0; r < rank; r++) {
uint64_t idx = strtoul(linePtr, &linePtr, 10);
// Add 0-based index.
@@ -1204,17 +1305,18 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
// The external formats always store the numerical values with the type
// double, but we cast these values to the sparse tensor object type.
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
- double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
- tensor->add(indices, value);
+ double value = stfile.isPattern() ? 1.0 : strtod(linePtr, &linePtr);
+ // TODO: <https://github.com/llvm/llvm-project/issues/54179>
+ coo->add(indices, value);
// We currently chose to deal with symmetric matrices by fully constructing
// them. In the future, we may want to make symmetry implicit for storage
// reasons.
- if (isSymmetric && indices[0] != indices[1])
- tensor->add({indices[1], indices[0]}, value);
+ if (stfile.isSymmetric() && indices[0] != indices[1])
+ coo->add({indices[1], indices[0]}, value);
}
// Close the file and return tensor.
- fclose(file);
- return tensor;
+ stfile.closeFile();
+ return coo;
}
/// Writes the sparse tensor to `dest` in extended FROSTT format.
@@ -1670,6 +1772,18 @@ char *getTensorFilename(index_type id) {
return env;
}
+void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
+ assert(out && "Received nullptr for out-parameter");
+ SparseTensorFile stfile(filename);
+ stfile.openFile();
+ stfile.readHeader();
+ stfile.closeFile();
+ const uint64_t rank = stfile.getRank();
+ const uint64_t *dimSizes = stfile.getDimSizes();
+ out->reserve(rank);
+ out->assign(dimSizes, dimSizes + rank);
+}
+
// TODO: generalize beyond 64-bit indices.
#define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \
void *convertToMLIRSparseTensor##VNAME( \
More information about the Mlir-commits
mailing list