[Mlir-commits] [mlir] 46c8422 - [mlir][transform] change RaggedArray internals

Alex Zinenko llvmlistbot at llvm.org
Thu Mar 16 16:14:01 PDT 2023


Author: Alex Zinenko
Date: 2023-03-16T23:13:53Z
New Revision: 46c8422d833ba69a157199c26955a5faaada2927

URL: https://github.com/llvm/llvm-project/commit/46c8422d833ba69a157199c26955a5faaada2927
DIFF: https://github.com/llvm/llvm-project/commit/46c8422d833ba69a157199c26955a5faaada2927.diff

LOG: [mlir][transform] change RaggedArray internals

Change the internal storage scheme from storing a MutableArrayRef to
storing an explicit offset+length pair. Storing an ArrayRef is dangerous
because it contains the pointer to the first element in the range, but
the entire storage vector may be reallocated, making the pointer
dangling. We don't know when the reallocation happens, so we can't
update the ArrayRefs. Store the explicit offset instead and construct
ArrayRefs on-the-fly.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D146239

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h
index c3c38d61bef3e..0ee23914fa4e1 100644
--- a/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h
+++ b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h
@@ -9,6 +9,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include <iterator>
 
 namespace mlir {
 /// A 2D array where each row may have different length. Elements of each row
@@ -25,15 +26,92 @@ class RaggedArray {
 
   /// Accesses `pos`-th row.
   ArrayRef<T> operator[](size_t pos) const { return at(pos); }
-  ArrayRef<T> at(size_t pos) const { return slices[pos]; }
+  ArrayRef<T> at(size_t pos) const {
+    if (slices[pos].first == static_cast<size_t>(-1))
+      return ArrayRef<T>();
+    return ArrayRef<T>(storage).slice(slices[pos].first, slices[pos].second);
+  }
   MutableArrayRef<T> operator[](size_t pos) { return at(pos); }
-  MutableArrayRef<T> at(size_t pos) { return slices[pos]; }
+  MutableArrayRef<T> at(size_t pos) {
+    if (slices[pos].first == static_cast<size_t>(-1))
+      return MutableArrayRef<T>();
+    return MutableArrayRef<T>(storage).slice(slices[pos].first,
+                                             slices[pos].second);
+  }
+
+  /// Iterator over the rows.
+  class iterator
+      : public llvm::iterator_facade_base<
+            iterator, std::forward_iterator_tag, MutableArrayRef<T>,
+            std::ptrdiff_t, MutableArrayRef<T> *, MutableArrayRef<T>> {
+  public:
+    /// Creates the start iterator.
+    explicit iterator(RaggedArray &ragged) : ragged(ragged), pos(0) {}
+
+    /// Creates the end iterator.
+    iterator(RaggedArray &ragged, size_t pos) : ragged(ragged), pos(pos) {}
+
+    /// Dereferences the current iterator. Assumes in-bounds.
+    MutableArrayRef<T> operator*() const { return ragged[pos]; }
+
+    /// Increments the iterator.
+    iterator &operator++() {
+      if (pos < ragged.slices.size())
+        ++pos;
+      return *this;
+    }
+
+    /// Compares the two iterators. Iterators into different ragged arrays
+    /// compare not equal.
+    bool operator==(const iterator &other) const {
+      return &ragged == &other.ragged && pos == other.pos;
+    }
+
+  private:
+    RaggedArray &ragged;
+    size_t pos;
+  };
+
+  /// Constant iterator over the rows.
+  class const_iterator
+      : public llvm::iterator_facade_base<
+            const_iterator, std::forward_iterator_tag, ArrayRef<T>,
+            std::ptrdiff_t, ArrayRef<T> *, ArrayRef<T>> {
+  public:
+    /// Creates the start iterator.
+    explicit const_iterator(const RaggedArray &ragged)
+        : ragged(ragged), pos(0) {}
+
+    /// Creates the end iterator.
+    const_iterator(const RaggedArray &ragged, size_t pos)
+        : ragged(ragged), pos(pos) {}
+
+    /// Dereferences the current iterator. Assumes in-bounds.
+    ArrayRef<T> operator*() const { return ragged[pos]; }
+
+    /// Increments the iterator.
+    const_iterator &operator++() {
+      if (pos < ragged.slices.size())
+        ++pos;
+      return *this;
+    }
+
+    /// Compares the two iterators. Iterators into different ragged arrays
+    /// compare not equal.
+    bool operator==(const const_iterator &other) const {
+      return &ragged == &other.ragged && pos == other.pos;
+    }
+
+  private:
+    const RaggedArray &ragged;
+    size_t pos;
+  };
 
   /// Iterator over rows.
-  auto begin() { return slices.begin(); }
-  auto begin() const { return slices.begin(); }
-  auto end() { return slices.end(); }
-  auto end() const { return slices.end(); }
+  const_iterator begin() const { return const_iterator(*this); }
+  const_iterator end() const { return const_iterator(*this, slices.size()); }
+  iterator begin() { return iterator(*this); }
+  iterator end() { return iterator(*this, slices.size()); }
 
   /// Reserve space to store `size` rows with `nestedSize` elements each.
   void reserve(size_t size, size_t nestedSize = 0) {
@@ -53,38 +131,41 @@ class RaggedArray {
   /// succeeding rows.
   template <typename Range>
   void replace(size_t pos, Range &&elements) {
-    auto from = slices[pos].data();
-    if (from != nullptr) {
-      auto to = std::next(from, slices[pos].size());
+    if (slices[pos].first != static_cast<size_t>(-1)) {
+      auto from = std::next(storage.begin(), slices[pos].first);
+      auto to = std::next(from, slices[pos].second);
       auto newFrom = storage.erase(from, to);
       // Update the array refs after the underlying storage was shifted.
       for (size_t i = pos + 1, e = size(); i < e; ++i) {
-        slices[i] = MutableArrayRef<T>(newFrom, slices[i].size());
-        std::advance(newFrom, slices[i].size());
+        slices[i] = std::make_pair(std::distance(storage.begin(), newFrom),
+                                   slices[i].second);
+        std::advance(newFrom, slices[i].second);
       }
     }
     slices[pos] = appendToStorage(std::forward<Range>(elements));
   }
 
   /// Appends `num` empty rows to the array.
-  void appendEmptyRows(size_t num) { slices.resize(slices.size() + num); }
+  void appendEmptyRows(size_t num) {
+    slices.resize(slices.size() + num, std::pair<size_t, size_t>(-1, 0));
+  }
 
 private:
-  /// Appends the given elements to the storage and returns an ArrayRef pointing
-  /// to them in the storage.
+  /// Appends the given elements to the storage and returns an ArrayRef
+  /// pointing to them in the storage.
   template <typename Range>
-  MutableArrayRef<T> appendToStorage(Range &&elements) {
+  std::pair<size_t, size_t> appendToStorage(Range &&elements) {
     size_t start = storage.size();
     llvm::append_range(storage, std::forward<Range>(elements));
-    return MutableArrayRef<T>(storage).drop_front(start);
+    return std::make_pair(start, storage.size() - start);
   }
 
-  /// Outer elements of the ragged array. Each entry is a reference to a
-  /// contiguous segment in the `storage` list that contains the actual
-  /// elements. This allows for elements to be stored contiguously without
-  /// nested vectors and for different segments to be set or replaced in any
-  /// order.
-  SmallVector<MutableArrayRef<T>> slices;
+  /// Outer elements of the ragged array. Each entry is an (offset, length)
+  /// pair identifying a contiguous segment in the `storage` list that
+  /// contains the actual elements. This allows for elements to be stored
+  /// contiguously without nested vectors and for different segments to be set
+  /// or replaced in any order.
+  SmallVector<std::pair<size_t, size_t>> slices;
 
   /// Dense storage for ragged array elements.
   SmallVector<T> storage;


        


More information about the Mlir-commits mailing list