[Mlir-commits] [mlir] 2556d62 - [mlir][sparse] assert fail on mismatch between rank and annotations array
Aart Bik
llvmlistbot at llvm.org
Thu Feb 18 23:22:36 PST 2021
Author: Aart Bik
Date: 2021-02-18T23:22:14-08:00
New Revision: 2556d622828ae5631ac483d82592440fa1910d80
URL: https://github.com/llvm/llvm-project/commit/2556d622828ae5631ac483d82592440fa1910d80
DIFF: https://github.com/llvm/llvm-project/commit/2556d622828ae5631ac483d82592440fa1910d80.diff
LOG: [mlir][sparse] assert fail on mismatch between rank and annotations array
Rationale:
Providing the wrong number of sparse/dense annotations was either
silently ignored or caused unrelated crashes. This minor change
verifies that the provided number matches the rank of the tensor.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D97034
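For context, a minimal sketch of the failure mode this patch addresses
(the file name, annotation values, and the standalone call are
illustrative assumptions, not part of the patch): pairing a rank-3
tensor file with only two per-dimension annotations previously walked
off the end of the sparsity array or went unnoticed; with this change
the templated reader asserts on the mismatch up front.

  // Hypothetical call sketch: two dense/sparse annotations supplied
  // for what turns out to be a rank-3 tensor file.
  char filename[] = "rank3.mtx";        // illustrative file name
  bool sparsity[2] = {false, true};     // rank-2 annotations only
  // Before: sparsity[2] was read out of bounds, or the mismatch was
  // silently ignored. After: the reader checks the rank first,
  //   assert(size == t->getRank());
  // and fails fast with a clear assertion.
  void *tensor =
      newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity,
                                                  /*size=*/2);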
Added:
Modified:
mlir/lib/ExecutionEngine/SparseUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index 0ff6f7d49b46..903b9f115182 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -76,8 +76,8 @@ struct SparseTensor {
}
/// Adds element as indices and value.
void add(const std::vector<uint64_t> &ind, double val) {
- assert(sizes.size() == ind.size());
- for (int64_t r = 0, rank = sizes.size(); r < rank; r++)
+ assert(getRank() == ind.size());
+ for (int64_t r = 0, rank = getRank(); r < rank; r++)
assert(ind[r] < sizes[r]); // within bounds
elements.emplace_back(Element(ind, val));
}
@@ -85,6 +85,8 @@ struct SparseTensor {
void sort() { std::sort(elements.begin(), elements.end(), lexOrder); }
/// Primitive one-time iteration.
const Element &next() { return elements[pos++]; }
+ /// Returns rank.
+ uint64_t getRank() const { return sizes.size(); }
/// Getter for sizes array.
const std::vector<uint64_t> &getSizes() const { return sizes; }
/// Getter for elements array.
@@ -139,13 +141,13 @@ class SparseTensorStorage : public SparseTensorStorageBase {
/// Constructs sparse tensor storage scheme following the given
/// per-rank dimension dense/sparse annotations.
SparseTensorStorage(SparseTensor *tensor, bool *sparsity)
- : sizes(tensor->getSizes()), pointers(sizes.size()),
- indices(sizes.size()) {
+ : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
// Provide hints on capacity.
// TODO: needs fine-tuning based on sparsity
- values.reserve(tensor->getElements().size());
- for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) {
- s *= tensor->getSizes()[d];
+ uint64_t nnz = tensor->getElements().size();
+ values.reserve(nnz);
+ for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) {
+ s *= sizes[d];
if (sparsity[d]) {
pointers[d].reserve(s + 1);
indices[d].reserve(s);
@@ -153,12 +155,16 @@ class SparseTensorStorage : public SparseTensorStorageBase {
}
}
// Then setup the tensor.
- traverse(tensor, sparsity, 0, tensor->getElements().size(), 0);
+ traverse(tensor, sparsity, 0, nnz, 0);
}
virtual ~SparseTensorStorage() {}
+ uint64_t getRank() const { return sizes.size(); }
+
uint64_t getDimSize(uint64_t d) override { return sizes[d]; }
+
+ // Partially specialize these three methods based on template types.
void getPointers(std::vector<P> **out, uint64_t d) override {
*out = &pointers[d];
}
@@ -176,7 +182,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
uint64_t d) {
const std::vector<Element> &elements = tensor->getElements();
// Once dimensions are exhausted, insert the numerical values.
- if (d == sizes.size()) {
+ if (d == getRank()) {
values.push_back(lo < hi ? elements[lo].value : 0.0);
return;
}
@@ -221,9 +227,10 @@ class SparseTensorStorage : public SparseTensorStorageBase {
/// Templated reader.
template <typename P, typename I, typename V>
-void *newSparseTensor(char *filename, bool *sparsity) {
+void *newSparseTensor(char *filename, bool *sparsity, uint64_t size) {
uint64_t idata[64];
SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
+ assert(size == t->getRank()); // sparsity array must match rank
SparseTensorStorageBase *tensor =
new SparseTensorStorage<P, I, V>(t, sparsity);
delete t;
@@ -481,21 +488,29 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
assert(astride == 1);
bool *sparsity = abase + aoff;
if (ptrTp == kU64 && indTp == kU64 && valTp == kF64)
- return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU64 && indTp == kU64 && valTp == kF32)
- return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity,
+ asize);
if (ptrTp == kU64 && indTp == kU32 && valTp == kF64)
- return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU64 && indTp == kU32 && valTp == kF32)
- return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU64 && valTp == kF64)
- return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU64 && valTp == kF32)
- return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU32 && valTp == kF64)
- return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU32 && valTp == kF32)
- return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity,
+ asize);
fputs("unsupported combination of types\n", stderr);
exit(1);
}
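As a usage sketch of the public entry point (assuming the expanded
memref-style argument list suggested by the hunk above; the file name
and annotation values are illustrative, not from the patch), the
annotation-array size asize now flows through to the rank check in the
templated reader:

  // Hypothetical caller: annotations for a rank-2 matrix, dense in
  // dimension 0 and sparse in dimension 1. asize must equal the rank
  // of the tensor read from the file, or the new assert fires.
  char filename[] = "matrix.mtx";
  bool annotations[2] = {false, true};
  void *tensor = newSparseTensor(filename, annotations, annotations,
                                 /*aoff=*/0, /*asize=*/2, /*astride=*/1,
                                 kU64, kU64, kF64);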