[Mlir-commits] [mlir] ab6334d - [mlir][sparse] add expanded size to API (#68614)
llvmlistbot at llvm.org
Mon Oct 9 14:42:15 PDT 2023
Author: Aart Bik
Date: 2023-10-09T14:42:11-07:00
New Revision: ab6334dd11d2679ffca877a1e444efb40221cfe1
URL: https://github.com/llvm/llvm-project/commit/ab6334dd11d2679ffca877a1e444efb40221cfe1
DIFF: https://github.com/llvm/llvm-project/commit/ab6334dd11d2679ffca877a1e444efb40221cfe1.diff
LOG: [mlir][sparse] add expanded size to API (#68614)
The expanded size is used (in assertions only) to verify that we do not run
out of bounds on the expanded access pattern.
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 0dd23ac52ac6790..0407bccaae8790c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -214,8 +214,10 @@ class SparseTensorStorageBase {
/// * `added` a map from `[0..count)` to last-level coordinates for
/// which `filled` is true and `values` contains the assotiated value.
/// * `count` the size of `added`.
+ /// * `expsz` the size of the expanded vector (verification only).
#define DECL_EXPINSERT(VNAME, V) \
- virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t);
+ virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t, \
+ uint64_t);
MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT
@@ -426,7 +428,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
/// Partially specialize expanded insertions based on template types.
void expInsert(uint64_t *lvlCoords, V *values, bool *filled, uint64_t *added,
- uint64_t count) final {
+ uint64_t count, uint64_t expsz) final {
assert((lvlCoords && values && filled && added) && "Received nullptr");
if (count == 0)
return;
@@ -435,6 +437,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
// Restore insertion path for first insert.
const uint64_t lastLvl = getLvlRank() - 1;
uint64_t c = added[0];
+ assert(c <= expsz);
assert(filled[c] && "added coordinate is not filled");
lvlCoords[lastLvl] = c;
lexInsert(lvlCoords, values[c]);
@@ -444,6 +447,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
for (uint64_t i = 1; i < count; ++i) {
assert(c < added[i] && "non-lexicographic insertion");
c = added[i];
+ assert(c <= expsz);
assert(filled[c] && "added coordinate is not filled");
lvlCoords[lastLvl] = c;
insPath(lvlCoords, lastLvl, added[i - 1] + 1, values[c]);
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 77861e074e9333b..199e4205a61a25b 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -90,7 +90,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
#define IMPL_EXPINSERT(VNAME, V) \
void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
- uint64_t) { \
+ uint64_t, uint64_t) { \
FATAL_PIV("expInsert" #VNAME); \
}
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 05da8cd79190ed0..8340fe7dcf925be 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -480,7 +480,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
V *values = MEMREF_GET_PAYLOAD(vref); \
bool *filled = MEMREF_GET_PAYLOAD(fref); \
index_type *added = MEMREF_GET_PAYLOAD(aref); \
- tensor.expInsert(lvlCoords, values, filled, added, count); \
+ uint64_t expsz = vref->sizes[0]; \
+ tensor.expInsert(lvlCoords, values, filled, added, count, expsz); \
}
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
#undef IMPL_EXPINSERT
More information about the Mlir-commits mailing list