[llvm-branch-commits] [mlir] 5959c28 - [mlir][sparse] add asserts on reading in tensor data
Aart Bik via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 20 14:35:13 PST 2021
Author: Aart Bik
Date: 2021-01-20T14:30:13-08:00
New Revision: 5959c28f24856f3d4a1db6b4743c66bdc6dcd735
URL: https://github.com/llvm/llvm-project/commit/5959c28f24856f3d4a1db6b4743c66bdc6dcd735
DIFF: https://github.com/llvm/llvm-project/commit/5959c28f24856f3d4a1db6b4743c66bdc6dcd735.diff
LOG: [mlir][sparse] add asserts on reading in tensor data
Rationale:
Since I made the argument that metadata helps with extra
verification checks, I had better actually do that ;-)
Reviewed By: penpornk
Differential Revision: https://reviews.llvm.org/D95072
Added:
Modified:
mlir/lib/ExecutionEngine/SparseUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index 376b989975b5..d1962661fe79 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -48,9 +48,9 @@ namespace {
/// and a rank-5 tensor element like
/// ({i,j,k,l,m}, a[i,j,k,l,m])
struct Element {
- Element(const std::vector<int64_t> &ind, double val)
+ Element(const std::vector<uint64_t> &ind, double val)
: indices(ind), value(val){};
- std::vector<int64_t> indices;
+ std::vector<uint64_t> indices;
double value;
};
@@ -61,9 +61,15 @@ struct Element {
/// formats require the elements to appear in lexicographic index order).
struct SparseTensor {
public:
- SparseTensor(int64_t capacity) : pos(0) { elements.reserve(capacity); }
+ SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
+ : sizes(szs), pos(0) {
+ elements.reserve(capacity);
+ }
// Add element as indices and value.
- void add(const std::vector<int64_t> &ind, double val) {
+ void add(const std::vector<uint64_t> &ind, double val) {
+ assert(sizes.size() == ind.size());
+ for (int64_t r = 0, rank = sizes.size(); r < rank; r++)
+ assert(ind[r] < sizes[r]); // within bounds
elements.emplace_back(Element(ind, val));
}
// Sort elements lexicographically by index.
@@ -82,6 +88,8 @@ struct SparseTensor {
}
return false;
}
+
+ std::vector<uint64_t> sizes; // per-rank dimension sizes
std::vector<Element> elements;
uint64_t pos;
};
@@ -225,20 +233,24 @@ extern "C" void *openTensorC(char *filename, uint64_t *idata) {
fprintf(stderr, "Unknown format %s\n", filename);
exit(1);
}
- // Read all nonzero elements.
+ // Prepare sparse tensor object with per-rank dimension sizes
+ // and the number of nonzeros as initial capacity.
uint64_t rank = idata[0];
uint64_t nnz = idata[1];
- SparseTensor *tensor = new SparseTensor(nnz);
- std::vector<int64_t> indices(rank);
- double value;
+ std::vector<uint64_t> indices(rank);
+ for (uint64_t r = 0; r < rank; r++)
+ indices[r] = idata[2 + r];
+ SparseTensor *tensor = new SparseTensor(indices, nnz);
+ // Read all nonzero elements.
for (uint64_t k = 0; k < nnz; k++) {
for (uint64_t r = 0; r < rank; r++) {
- if (fscanf(file, "%" PRId64, &indices[r]) != 1) {
+ if (fscanf(file, "%" PRIu64, &indices[r]) != 1) {
fprintf(stderr, "Cannot find next index in %s\n", filename);
exit(1);
}
indices[r]--; // 0-based index
}
+ double value;
if (fscanf(file, "%lg\n", &value) != 1) {
fprintf(stderr, "Cannot find next value in %s\n", filename);
exit(1);
More information about the llvm-branch-commits mailing list