[clang-tools-extra] b37dafd - [pseudo] Store shift and goto actions in a compact structure with faster lookup.

Sam McCall via cfe-commits cfe-commits at lists.llvm.org
Mon Jul 4 10:40:16 PDT 2022


Author: Sam McCall
Date: 2022-07-04T19:40:04+02:00
New Revision: b37dafd5dc83a5f1fc4ca7e37e4944364ff9d5b7

URL: https://github.com/llvm/llvm-project/commit/b37dafd5dc83a5f1fc4ca7e37e4944364ff9d5b7
DIFF: https://github.com/llvm/llvm-project/commit/b37dafd5dc83a5f1fc4ca7e37e4944364ff9d5b7.diff

LOG: [pseudo] Store shift and goto actions in a compact structure with faster lookup.

The actions table is very compact but the binary search to find the
correct action is relatively expensive.
A hashtable is faster but pretty large (64 bits per value, plus empty
slots, and lookup is constant time but not trivial due to collisions).

The structure in this patch uses 1.25 bits per entry (whether present or absent)
plus the size of the values, and lookup is trivial.

The Shift table is 119KB = 27KB values + 92KB keys.
The Goto table is 86KB = 30KB values + 57KB keys.
(Goto has a smaller keyspace as #nonterminals < #terminals, and more entries).

This patch improves glrParse speed by 28%: 4.69 => 5.99 MB/s
Overall the table grows by 60%: 142 => 228KB.

By comparison, DenseMap<unsigned, StateID> is "only" 16% faster (5.43 MB/s),
and results in a 285% larger table (547 KB) vs the baseline.

Differential Revision: https://reviews.llvm.org/D128485

Added: 
    

Modified: 
    clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
    clang-tools-extra/pseudo/lib/GLR.cpp
    clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
    clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
    clang-tools-extra/pseudo/unittests/LRTableTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h b/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
index d480956704960..cd183b552d6f9 100644
--- a/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
+++ b/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
@@ -40,6 +40,8 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/Support/Capacity.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 #include <vector>
 
@@ -123,12 +125,18 @@ class LRTable {
 
   // Returns the state after we reduce a nonterminal.
   // Expected to be called by LR parsers.
-  // REQUIRES: Nonterminal is valid here.
-  StateID getGoToState(StateID State, SymbolID Nonterminal) const;
+  // If the nonterminal is invalid here, returns None.
+  llvm::Optional<StateID> getGoToState(StateID State,
+                                       SymbolID Nonterminal) const {
+    return Gotos.get(gotoIndex(State, Nonterminal, numStates()));
+  }
   // Returns the state after we shift a terminal.
   // Expected to be called by LR parsers.
   // If the terminal is invalid here, returns None.
-  llvm::Optional<StateID> getShiftState(StateID State, SymbolID Terminal) const;
+  llvm::Optional<StateID> getShiftState(StateID State,
+                                        SymbolID Terminal) const {
+    return Shifts.get(shiftIndex(State, Terminal, numStates()));
+  }
 
   // Returns the possible reductions from a state.
   //
@@ -164,9 +172,7 @@ class LRTable {
   StateID getStartState(SymbolID StartSymbol) const;
 
   size_t bytes() const {
-    return sizeof(*this) + llvm::capacity_in_bytes(Actions) +
-           llvm::capacity_in_bytes(Symbols) +
-           llvm::capacity_in_bytes(StateOffset) +
+    return sizeof(*this) + Gotos.bytes() + Shifts.bytes() +
            llvm::capacity_in_bytes(Reduces) +
            llvm::capacity_in_bytes(ReduceOffset) +
            llvm::capacity_in_bytes(FollowSets);
@@ -194,22 +200,92 @@ class LRTable {
                                llvm::ArrayRef<ReduceEntry>);
 
 private:
-  // Looks up actions stored in the generic table.
-  llvm::ArrayRef<Action> find(StateID State, SymbolID Symbol) const;
-
-  // Conceptually the LR table is a multimap from (State, SymbolID) => Action.
-  // Our physical representation is quite different for compactness.
-
-  // Index is StateID, value is the offset into Symbols/Actions
-  // where the entries for this state begin.
-  // Give a state id, the corresponding half-open range of Symbols/Actions is
-  // [StateOffset[id], StateOffset[id+1]).
-  std::vector<uint32_t> StateOffset;
-  // Parallel to Actions, the value is SymbolID (columns of the matrix).
-  // Grouped by the StateID, and only subranges are sorted.
-  std::vector<SymbolID> Symbols;
-  // A flat list of available actions, sorted by (State, SymbolID).
-  std::vector<Action> Actions;
+  unsigned numStates() const { return ReduceOffset.size() - 1; }
+
+  // A map from unsigned key => StateID, used to store actions.
+  // The keys should be sequential but the values are somewhat sparse.
+  //
+  // In practice, the keys encode (origin state, symbol) pairs, and the values
+  // are the state we should move to after seeing that symbol.
+  //
+  // We store one bit for presence/absence of the value for each key.
+  // At every 64th key, we store the offset into the table of values.
+  //   e.g. key 0x500 is checkpoint 0x500/64 = 20
+  //                     Checkpoints[20] = 34
+  //        get(0x500) = Values[34]                (assuming it has a value)
+  // To look up values in between, we count the set bits:
+  //        get(0x509) has a value if HasValue[20] & (1<<9)
+  //        #values between 0x500 and 0x509: popcnt(HasValue[20] & ((1<<9) - 1))
+  //        get(0x509) = Values[34 + popcnt(...)]
+  //
+  // Overall size is 1.25 bits/key + 16 bits/value.
+  // Lookup is constant time with a low factor (no hashing).
+  class TransitionTable {
+    using Word = uint64_t;
+    constexpr static unsigned WordBits = CHAR_BIT * sizeof(Word);
+
+    std::vector<StateID> Values;
+    std::vector<Word> HasValue;
+    std::vector<uint16_t> Checkpoints;
+
+  public:
+    TransitionTable() = default;
+    TransitionTable(const llvm::DenseMap<unsigned, StateID> &Entries,
+                    unsigned NumKeys) {
+      assert(
+          Entries.size() <
+              std::numeric_limits<decltype(Checkpoints)::value_type>::max() &&
+          "16 bits too small for value offsets!");
+      unsigned NumWords = (NumKeys + WordBits - 1) / WordBits;
+      HasValue.resize(NumWords, 0);
+      Checkpoints.reserve(NumWords);
+      Values.reserve(Entries.size());
+      for (unsigned I = 0; I < NumKeys; ++I) {
+        if ((I % WordBits) == 0)
+          Checkpoints.push_back(Values.size());
+        auto It = Entries.find(I);
+        if (It != Entries.end()) {
+          HasValue[I / WordBits] |= (Word(1) << (I % WordBits));
+          Values.push_back(It->second);
+        }
+      }
+    }
+
+    llvm::Optional<StateID> get(unsigned Key) const {
+      // Do we have a value for this key?
+      Word KeyMask = Word(1) << (Key % WordBits);
+      unsigned KeyWord = Key / WordBits;
+      if ((HasValue[KeyWord] & KeyMask) == 0)
+        return llvm::None;
+      // Count the number of values since the checkpoint.
+      Word BelowKeyMask = KeyMask - 1;
+      unsigned CountSinceCheckpoint =
+          llvm::countPopulation(HasValue[KeyWord] & BelowKeyMask);
+      // Find the value relative to the last checkpoint.
+      return Values[Checkpoints[KeyWord] + CountSinceCheckpoint];
+    }
+
+    unsigned size() const { return Values.size(); }
+
+    size_t bytes() const {
+      return llvm::capacity_in_bytes(HasValue) +
+             llvm::capacity_in_bytes(Values) +
+             llvm::capacity_in_bytes(Checkpoints);
+    }
+  };
+  // Shift and Goto tables are keyed by encoded (State, Symbol).
+  static unsigned shiftIndex(StateID State, SymbolID Terminal,
+                             unsigned NumStates) {
+    return NumStates * symbolToToken(Terminal) + State;
+  }
+  static unsigned gotoIndex(StateID State, SymbolID Nonterminal,
+                            unsigned NumStates) {
+    assert(isNonterminal(Nonterminal));
+    return NumStates * Nonterminal + State;
+  }
+  TransitionTable Shifts;
+  TransitionTable Gotos;
+
   // A sorted table, storing the start state for each target parsing symbol.
   std::vector<std::pair<SymbolID, StateID>> StartStates;
 

diff  --git a/clang-tools-extra/pseudo/lib/GLR.cpp b/clang-tools-extra/pseudo/lib/GLR.cpp
index 0b9cf46245a96..6373024b3db65 100644
--- a/clang-tools-extra/pseudo/lib/GLR.cpp
+++ b/clang-tools-extra/pseudo/lib/GLR.cpp
@@ -318,9 +318,11 @@ class GLRReduce {
     do {
       const PushSpec &Push = Sequences.top().second;
       FamilySequences.emplace_back(Sequences.top().first.Rule, *Push.Seq);
-      for (const GSS::Node *Base : Push.LastPop->parents())
-        FamilyBases.emplace_back(
-            Params.Table.getGoToState(Base->State, F.Symbol), Base);
+      for (const GSS::Node *Base : Push.LastPop->parents()) {
+        auto NextState = Params.Table.getGoToState(Base->State, F.Symbol);
+        assert(NextState.hasValue() && "goto must succeed after reduce!");
+        FamilyBases.emplace_back(*NextState, Base);
+      }
 
       Sequences.pop();
     } while (!Sequences.empty() && Sequences.top().first == F);
@@ -393,8 +395,9 @@ class GLRReduce {
     }
     const ForestNode *Parsed =
         &Params.Forest.createSequence(Rule.Target, *RID, TempSequence);
-    StateID NextState = Params.Table.getGoToState(Base->State, Rule.Target);
-    Heads->push_back(Params.GSStack.addNode(NextState, Parsed, {Base}));
+    auto NextState = Params.Table.getGoToState(Base->State, Rule.Target);
+    assert(NextState.hasValue() && "goto must succeed after reduce!");
+    Heads->push_back(Params.GSStack.addNode(*NextState, Parsed, {Base}));
     return true;
   }
 };
@@ -444,7 +447,8 @@ const ForestNode &glrParse(const TokenStream &Tokens, const ParseParams &Params,
   }
   LLVM_DEBUG(llvm::dbgs() << llvm::formatv("Reached eof\n"));
 
-  StateID AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
+  auto AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
+  assert(AcceptState.hasValue() && "goto must succeed after start symbol!");
   const ForestNode *Result = nullptr;
   for (const auto *Head : Heads) {
     if (Head->State == AcceptState) {

diff  --git a/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp b/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
index 3b35d232b8a44..058970be5d60e 100644
--- a/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
+++ b/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
@@ -34,11 +34,10 @@ std::string LRTable::dumpStatistics() const {
   return llvm::formatv(R"(
 Statistics of the LR parsing table:
     number of states: {0}
-    number of actions: {1}
-    number of reduces: {2}
-    size of the table (bytes): {3}
+    number of actions: shift={1} goto={2} reduce={3}
+    size of the table (bytes): {4}
 )",
-                       StateOffset.size() - 1, Actions.size(), Reduces.size(),
+                       numStates(), Shifts.size(), Gotos.size(), Reduces.size(),
                        bytes())
       .str();
 }
@@ -47,15 +46,13 @@ std::string LRTable::dumpForTests(const Grammar &G) const {
   std::string Result;
   llvm::raw_string_ostream OS(Result);
   OS << "LRTable:\n";
-  for (StateID S = 0; S < StateOffset.size() - 1; ++S) {
+  for (StateID S = 0; S < numStates(); ++S) {
     OS << llvm::formatv("State {0}\n", S);
     for (uint16_t Terminal = 0; Terminal < NumTerminals; ++Terminal) {
       SymbolID TokID = tokenSymbol(static_cast<tok::TokenKind>(Terminal));
-      for (auto A : find(S, TokID)) {
-        if (A.kind() == LRTable::Action::Shift)
-          OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
-                                        G.symbolName(TokID), A.getShiftState());
-      }
+      if (auto SS = getShiftState(S, TokID))
+        OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
+                                      G.symbolName(TokID), SS);
     }
     for (RuleID R : getReduceRules(S)) {
       SymbolID Target = G.lookupRule(R).Target;
@@ -71,55 +68,15 @@ std::string LRTable::dumpForTests(const Grammar &G) const {
     }
     for (SymbolID NontermID = 0; NontermID < G.table().Nonterminals.size();
          ++NontermID) {
-      if (find(S, NontermID).empty())
-        continue;
-      OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
-                                    G.symbolName(NontermID),
-                                    getGoToState(S, NontermID));
+      if (auto GS = getGoToState(S, NontermID)) {
+        OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
+                                      G.symbolName(NontermID), *GS);
+      }
     }
   }
   return OS.str();
 }
 
-llvm::Optional<LRTable::StateID>
-LRTable::getShiftState(StateID State, SymbolID Terminal) const {
-  // FIXME: we spend a significant amount of time on misses here.
-  // We could consider storing a std::bitset for a cheaper test?
-  assert(pseudo::isToken(Terminal) && "expected terminal symbol!");
-  for (const auto &Result : find(State, Terminal))
-    if (Result.kind() == Action::Shift)
-      return Result.getShiftState(); // unique: no shift/shift conflicts.
-  return llvm::None;
-}
-
-LRTable::StateID LRTable::getGoToState(StateID State,
-                                       SymbolID Nonterminal) const {
-  assert(pseudo::isNonterminal(Nonterminal) && "expected nonterminal symbol!");
-  auto Result = find(State, Nonterminal);
-  assert(Result.size() == 1 && Result.front().kind() == Action::GoTo);
-  return Result.front().getGoToState();
-}
-
-llvm::ArrayRef<LRTable::Action> LRTable::find(StateID Src, SymbolID ID) const {
-  assert(Src + 1u < StateOffset.size());
-  std::pair<size_t, size_t> Range =
-      std::make_pair(StateOffset[Src], StateOffset[Src + 1]);
-  auto SymbolRange = llvm::makeArrayRef(Symbols.data() + Range.first,
-                                        Symbols.data() + Range.second);
-
-  assert(llvm::is_sorted(SymbolRange) &&
-         "subrange of the Symbols should be sorted!");
-  const LRTable::StateID *Start =
-      llvm::partition_point(SymbolRange, [&ID](SymbolID S) { return S < ID; });
-  if (Start == SymbolRange.end())
-    return {};
-  const LRTable::StateID *End = Start;
-  while (End != SymbolRange.end() && *End == ID)
-    ++End;
-  return llvm::makeArrayRef(&Actions[Start - Symbols.data()],
-                            /*length=*/End - Start);
-}
-
 LRTable::StateID LRTable::getStartState(SymbolID Target) const {
   assert(llvm::is_sorted(StartStates) && "StartStates must be sorted!");
   auto It = llvm::partition_point(

diff  --git a/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp b/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
index 59ea4ce5e3276..77eabfdd39a3f 100644
--- a/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
+++ b/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
@@ -45,49 +45,24 @@ struct LRTable::Builder {
   llvm::DenseMap<StateID, llvm::SmallSet<RuleID, 4>> Reduces;
   std::vector<llvm::DenseSet<SymbolID>> FollowSets;
 
-  LRTable build(unsigned NumStates) && {
-    // E.g. given the following parsing table with 3 states and 3 terminals:
-    //
-    //            a    b     c
-    // +-------+----+-------+-+
-    // |state0 |    | s0,r0 | |
-    // |state1 | acc|       | |
-    // |state2 |    |  r1   | |
-    // +-------+----+-------+-+
-    //
-    // The final LRTable:
-    //  - StateOffset: [s0] = 0, [s1] = 2, [s2] = 3, [sentinel] = 4
-    //  - Symbols:     [ b,   b,   a,  b]
-    //    Actions:     [ s0, r0, acc, r1]
-    //                   ~~~~~~ range for state 0
-    //                           ~~~~ range for state 1
-    //                                ~~ range for state 2
-    // First step, we sort all entries by (State, Symbol, Action).
-    std::vector<Entry> Sorted(Entries.begin(), Entries.end());
-    llvm::sort(Sorted, [](const Entry &L, const Entry &R) {
-      return std::forward_as_tuple(L.State, L.Symbol, L.Act.opaque()) <
-             std::forward_as_tuple(R.State, R.Symbol, R.Act.opaque());
-    });
-
+  LRTable build(unsigned NumStates, unsigned NumNonterminals) && {
     LRTable Table;
-    Table.Actions.reserve(Sorted.size());
-    Table.Symbols.reserve(Sorted.size());
-    // We are good to finalize the States and Actions.
-    for (const auto &E : Sorted) {
-      Table.Actions.push_back(E.Act);
-      Table.Symbols.push_back(E.Symbol);
-    }
-    // Initialize the terminal and nonterminal offset, all ranges are empty by
-    // default.
-    Table.StateOffset = std::vector<uint32_t>(NumStates + 1, 0);
-    size_t SortedIndex = 0;
-    for (StateID State = 0; State < Table.StateOffset.size(); ++State) {
-      Table.StateOffset[State] = SortedIndex;
-      while (SortedIndex < Sorted.size() && Sorted[SortedIndex].State == State)
-        ++SortedIndex;
-    }
     Table.StartStates = std::move(StartStates);
 
+    // Compile the goto and shift actions into transition tables.
+    llvm::DenseMap<unsigned, SymbolID> Gotos;
+    llvm::DenseMap<unsigned, SymbolID> Shifts;
+    for (const auto &E : Entries) {
+      if (E.Act.kind() == Action::Shift)
+        Shifts.try_emplace(shiftIndex(E.State, E.Symbol, NumStates),
+                           E.Act.getShiftState());
+      else if (E.Act.kind() == Action::GoTo)
+        Gotos.try_emplace(gotoIndex(E.State, E.Symbol, NumStates),
+                          E.Act.getGoToState());
+    }
+    Table.Shifts = TransitionTable(Shifts, NumStates * NumTerminals);
+    Table.Gotos = TransitionTable(Gotos, NumStates * NumNonterminals);
+
     // Compile the follow sets into a bitmap.
     Table.FollowSets.resize(tok::NUM_TOKENS * FollowSets.size());
     for (SymbolID NT = 0; NT < FollowSets.size(); ++NT)
@@ -128,7 +103,8 @@ LRTable LRTable::buildForTests(const Grammar &G, llvm::ArrayRef<Entry> Entries,
   for (const ReduceEntry &E : Reduces)
     Build.Reduces[E.State].insert(E.Rule);
   Build.FollowSets = followSets(G);
-  return std::move(Build).build(/*NumStates=*/MaxState + 1);
+  return std::move(Build).build(/*NumStates=*/MaxState + 1,
+                                G.table().Nonterminals.size());
 }
 
 LRTable LRTable::buildSLR(const Grammar &G) {
@@ -156,7 +132,8 @@ LRTable LRTable::buildSLR(const Grammar &G) {
         Build.Reduces[SID].insert(I.rule());
     }
   }
-  return std::move(Build).build(Graph.states().size());
+  return std::move(Build).build(Graph.states().size(),
+                                G.table().Nonterminals.size());
 }
 
 } // namespace pseudo

diff  --git a/clang-tools-extra/pseudo/unittests/LRTableTest.cpp b/clang-tools-extra/pseudo/unittests/LRTableTest.cpp
index f317bcbd81517..061bf7ddd5f4a 100644
--- a/clang-tools-extra/pseudo/unittests/LRTableTest.cpp
+++ b/clang-tools-extra/pseudo/unittests/LRTableTest.cpp
@@ -60,7 +60,7 @@ TEST(LRTable, Builder) {
 
   EXPECT_EQ(T.getShiftState(1, Eof), llvm::None);
   EXPECT_EQ(T.getShiftState(1, Identifier), llvm::None);
-  EXPECT_EQ(T.getGoToState(1, Term), 3);
+  EXPECT_THAT(T.getGoToState(1, Term), ValueIs(3));
   EXPECT_THAT(T.getReduceRules(1), ElementsAre(2));
 
   // Verify the behaivor for other non-available-actions terminals.


        


More information about the cfe-commits mailing list