[Mlir-commits] [mlir] [mlir][sparse] add expanded size to API (PR #68614)

Aart Bik llvmlistbot at llvm.org
Mon Oct 9 10:45:20 PDT 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/68614

The expanded size is used to assert that the expanded access pattern does not run out of bounds.

>From 436337fab4eca9e55b80fba9f424d97e38718a0d Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 9 Oct 2023 10:40:31 -0700
Subject: [PATCH] [mlir][sparse] add expanded size to API

The expanded size is used to assert that the expanded
access pattern does not run out of bounds.
---
 mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h | 8 ++++++--
 mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp        | 2 +-
 mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp         | 3 ++-
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 0dd23ac52ac6790..0407bccaae8790c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -214,8 +214,10 @@ class SparseTensorStorageBase {
   /// * `added` a map from `[0..count)` to last-level coordinates for
   ///   which `filled` is true and `values` contains the assotiated value.
   /// * `count` the size of `added`.
+  /// * `expsz` the size of the expanded vector (verification only).
 #define DECL_EXPINSERT(VNAME, V)                                               \
-  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t);
+  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t,        \
+                         uint64_t);
   MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
 #undef DECL_EXPINSERT
 
@@ -426,7 +428,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
   /// Partially specialize expanded insertions based on template types.
   void expInsert(uint64_t *lvlCoords, V *values, bool *filled, uint64_t *added,
-                 uint64_t count) final {
+                 uint64_t count, uint64_t expsz) final {
     assert((lvlCoords && values && filled && added) && "Received nullptr");
     if (count == 0)
       return;
@@ -435,6 +437,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     // Restore insertion path for first insert.
     const uint64_t lastLvl = getLvlRank() - 1;
     uint64_t c = added[0];
+    assert(c < expsz && "expanded access out of bounds");
     assert(filled[c] && "added coordinate is not filled");
     lvlCoords[lastLvl] = c;
     lexInsert(lvlCoords, values[c]);
@@ -444,6 +447,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     for (uint64_t i = 1; i < count; ++i) {
       assert(c < added[i] && "non-lexicographic insertion");
       c = added[i];
+      assert(c < expsz && "expanded access out of bounds");
       assert(filled[c] && "added coordinate is not filled");
       lvlCoords[lastLvl] = c;
       insPath(lvlCoords, lastLvl, added[i - 1] + 1, values[c]);
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 77861e074e9333b..199e4205a61a25b 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -90,7 +90,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
 
 #define IMPL_EXPINSERT(VNAME, V)                                               \
   void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
-                                          uint64_t) {                          \
+                                          uint64_t, uint64_t) {                \
     FATAL_PIV("expInsert" #VNAME);                                             \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 05da8cd79190ed0..8340fe7dcf925be 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -480,7 +480,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
     V *values = MEMREF_GET_PAYLOAD(vref);                                      \
     bool *filled = MEMREF_GET_PAYLOAD(fref);                                   \
     index_type *added = MEMREF_GET_PAYLOAD(aref);                              \
-    tensor.expInsert(lvlCoords, values, filled, added, count);                 \
+    uint64_t expsz = vref->sizes[0];                                           \
+    tensor.expInsert(lvlCoords, values, filled, added, count, expsz);          \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
 #undef IMPL_EXPINSERT



More information about the Mlir-commits mailing list