[clang-tools-extra] 85eaecb - [pseudo] Check follow-sets instead of tying reduce actions to lookahead tokens.

Sam McCall via cfe-commits cfe-commits at lists.llvm.org
Mon Jun 27 15:36:25 PDT 2022


Author: Sam McCall
Date: 2022-06-28T00:36:16+02:00
New Revision: 85eaecbe8e541924b6f87dd83f169056e74ce237

URL: https://github.com/llvm/llvm-project/commit/85eaecbe8e541924b6f87dd83f169056e74ce237
DIFF: https://github.com/llvm/llvm-project/commit/85eaecbe8e541924b6f87dd83f169056e74ce237.diff

LOG: [pseudo] Check follow-sets instead of tying reduce actions to lookahead tokens.

Previously, the action table stored a reduce action for each lookahead
token it should allow. These tokens are the followSet(action.rule.target).

In practice, the follow sets are large, so we spend a lot of time binary
searching through all these near-duplicate entries to check whether our
lookahead token is there.
However, the number of reduces for a given state is very small, so we're
much better off scanning them linearly and performing a fast check for each.

D128318 was an attempt at this, storing a bitmap for each reduce.
However it's even more compact just to use the follow sets directly, as
there are fewer nonterminals than (state, rule) pairs. It's also faster.

This specialized approach means unbundling Reduce from other actions in
LRTable, so it's no longer useful to support it in Action. I suspect
Action will soon go away, as we store each kind of action separately.

This improves glrParse speed by 42% (3.30 -> 4.69 MB/s).
It also reduces LR table size by 59% (343 -> 142kB).

Differential Revision: https://reviews.llvm.org/D128472

Added: 
    

Modified: 
    clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
    clang-tools-extra/pseudo/lib/GLR.cpp
    clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
    clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
    clang-tools-extra/pseudo/test/lr-build-basic.test
    clang-tools-extra/pseudo/test/lr-build-conflicts.test
    clang-tools-extra/pseudo/unittests/GLRTest.cpp
    clang-tools-extra/pseudo/unittests/LRTableTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h b/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
index ab619774d93da..70ce52924f110 100644
--- a/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
+++ b/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
@@ -38,6 +38,8 @@
 
 #include "clang-pseudo/grammar/Grammar.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/Support/Capacity.h"
 #include <cstdint>
 #include <vector>
 
@@ -62,6 +64,9 @@ class LRTable {
 
   // Action represents the terminal and nonterminal actions, it combines the
   // entry of the ACTION and GOTO tables from the LR literature.
+  //
+  // FIXME: as we move away from a homogeneous table structure shared between
+  // action types, this class becomes less useful. Remove it.
   class Action {
   public:
     enum Kind : uint8_t {
@@ -73,8 +78,6 @@ class LRTable {
       // A shift is a forward transition, and the value n is the next state that
       // the parser is to enter.
       Shift,
-      // Reduce by a rule: pop the state stack.
-      Reduce,
 
       // NOTE: there are no typical accept actions in the LRtable, accept
       // actions are handled specifically in the parser -- if the parser
@@ -91,7 +94,6 @@ class LRTable {
 
     static Action goTo(StateID S) { return Action(GoTo, S); }
     static Action shift(StateID S) { return Action(Shift, S); }
-    static Action reduce(RuleID RID) { return Action(Reduce, RID); }
     static Action sentinel() { return Action(Sentinel, 0); }
 
     StateID getShiftState() const {
@@ -102,10 +104,6 @@ class LRTable {
       assert(kind() == GoTo);
       return Value;
     }
-    RuleID getReduceRule() const {
-      assert(kind() == Reduce);
-      return Value;
-    }
     Kind kind() const { return static_cast<Kind>(K); }
 
     bool operator==(const Action &L) const { return opaque() == L.opaque(); }
@@ -123,9 +121,6 @@ class LRTable {
     uint16_t Value : ValueBits;
   };
 
-  // Returns all available actions for the given state on a terminal.
-  // Expected to be called by LR parsers.
-  llvm::ArrayRef<Action> getActions(StateID State, SymbolID Terminal) const;
   // Returns the state after we reduce a nonterminal.
   // Expected to be called by LR parsers.
   // REQUIRES: Nonterminal is valid here.
@@ -135,9 +130,26 @@ class LRTable {
   // If the terminal is invalid here, returns None.
   llvm::Optional<StateID> getShiftState(StateID State, SymbolID Terminal) const;
 
-  // Looks up available actions.
-  // Returns empty if no available actions in the table.
-  llvm::ArrayRef<Action> find(StateID State, SymbolID Symbol) const;
+  // Returns the possible reductions from a state.
+  //
+  // These are not keyed by a lookahead token. Instead, call canFollow() to
+  // check whether a reduction should apply in the current context:
+  //   for (RuleID R : LR.getReduceRules(S)) {
+  //     if (!LR.canFollow(G.lookupRule(R).Target, NextToken))
+  //       continue;
+  //     // ...apply reduce...
+  //   }
+  llvm::ArrayRef<RuleID> getReduceRules(StateID State) const {
+    return llvm::makeArrayRef(&Reduces[ReduceOffset[State]],
+                              &Reduces[ReduceOffset[State + 1]]);
+  }
+  // Returns whether Terminal can follow Nonterminal in a valid source file.
+  bool canFollow(SymbolID Nonterminal, SymbolID Terminal) const {
+    assert(isToken(Terminal));
+    assert(isNonterminal(Nonterminal));
+    return FollowSets.test(tok::NUM_TOKENS * Nonterminal +
+                           symbolToToken(Terminal));
+  }
 
   // Returns the state from which the LR parser should start to parse the input
   // tokens as the given StartSymbol.
@@ -151,9 +163,12 @@ class LRTable {
   StateID getStartState(SymbolID StartSymbol) const;
 
   size_t bytes() const {
-    return sizeof(*this) + Actions.capacity() * sizeof(Action) +
-           Symbols.capacity() * sizeof(SymbolID) +
-           StateOffset.capacity() * sizeof(uint32_t);
+    return sizeof(*this) + llvm::capacity_in_bytes(Actions) +
+           llvm::capacity_in_bytes(Symbols) +
+           llvm::capacity_in_bytes(StateOffset) +
+           llvm::capacity_in_bytes(Reduces) +
+           llvm::capacity_in_bytes(ReduceOffset) +
+           llvm::capacity_in_bytes(FollowSets);
   }
 
   std::string dumpStatistics() const;
@@ -162,17 +177,25 @@ class LRTable {
   // Build a SLR(1) parsing table.
   static LRTable buildSLR(const Grammar &G);
 
-  class Builder;
+  struct Builder;
   // Represents an entry in the table, used for building the LRTable.
   struct Entry {
     StateID State;
     SymbolID Symbol;
     Action Act;
   };
+  struct ReduceEntry {
+    StateID State;
+    RuleID Rule;
+  };
   // Build a specifid table for testing purposes.
-  static LRTable buildForTests(const GrammarTable &, llvm::ArrayRef<Entry>);
+  static LRTable buildForTests(const Grammar &G, llvm::ArrayRef<Entry>,
+                               llvm::ArrayRef<ReduceEntry>);
 
 private:
+  // Looks up actions stored in the generic table.
+  llvm::ArrayRef<Action> find(StateID State, SymbolID Symbol) const;
+
   // Conceptually the LR table is a multimap from (State, SymbolID) => Action.
   // Our physical representation is quite different for compactness.
 
@@ -188,6 +211,17 @@ class LRTable {
   std::vector<Action> Actions;
   // A sorted table, storing the start state for each target parsing symbol.
   std::vector<std::pair<SymbolID, StateID>> StartStates;
+
+  // Given a state ID S, the half-open range of Reduces is
+  // [ReduceOffset[S], ReduceOffset[S+1])
+  std::vector<uint32_t> ReduceOffset;
+  std::vector<RuleID> Reduces;
+  // Conceptually this is a bool[SymbolID][Token], each entry describing whether
+  // the grammar allows the (nonterminal) symbol to be followed by the token.
+  //
+  // This is flattened by encoding the (SymbolID Nonterminal, tok::Kind Token)
+  // as an index: Nonterminal * NUM_TOKENS + Token.
+  llvm::BitVector FollowSets;
 };
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const LRTable::Action &);
 

diff  --git a/clang-tools-extra/pseudo/lib/GLR.cpp b/clang-tools-extra/pseudo/lib/GLR.cpp
index d93f682afac6c..0b9cf46245a96 100644
--- a/clang-tools-extra/pseudo/lib/GLR.cpp
+++ b/clang-tools-extra/pseudo/lib/GLR.cpp
@@ -251,9 +251,8 @@ class GLRReduce {
 private:
   // pop walks up the parent chain(s) for a reduction from Head by to Rule.
   // Once we reach the end, record the bases and sequences.
-  void pop(const GSS::Node *Head, RuleID RID) {
+  void pop(const GSS::Node *Head, RuleID RID, const Rule &Rule) {
     LLVM_DEBUG(llvm::dbgs() << "  Pop " << Params.G.dumpRule(RID) << "\n");
-    const auto &Rule = Params.G.lookupRule(RID);
     Family F{/*Start=*/0, /*Symbol=*/Rule.Target, /*Rule=*/RID};
     TempSequence.resize_for_overwrite(Rule.Size);
     auto DFS = [&](const GSS::Node *N, unsigned I, auto &DFS) {
@@ -286,11 +285,11 @@ class GLRReduce {
       // In trivial cases, we perform the complete reduce here!
       if (popAndPushTrivial())
         continue;
-      for (const auto &A :
-           Params.Table.getActions((*Heads)[NextPopHead]->State, Lookahead)) {
-        if (A.kind() != LRTable::Action::Reduce)
-          continue;
-        pop((*Heads)[NextPopHead], A.getReduceRule());
+      for (RuleID RID :
+           Params.Table.getReduceRules((*Heads)[NextPopHead]->State)) {
+        const auto &Rule = Params.G.lookupRule(RID);
+        if (Params.Table.canFollow(Rule.Target, Lookahead))
+          pop((*Heads)[NextPopHead], RID, Rule);
       }
     }
   }
@@ -367,21 +366,23 @@ class GLRReduce {
   //  - the head must have only one reduction rule
   //  - the reduction path must be a straight line (no multiple parents)
   // (Roughly this means there's no local ambiguity, so the LR algorithm works).
+  //
+  // Returns true if we successfully consumed the next unpopped head.
   bool popAndPushTrivial() {
     if (!Sequences.empty() || Heads->size() != NextPopHead + 1)
       return false;
     const GSS::Node *Head = Heads->back();
     llvm::Optional<RuleID> RID;
-    for (auto &A : Params.Table.getActions(Head->State, Lookahead)) {
-      if (A.kind() != LRTable::Action::Reduce)
-        continue;
-      if (RID)
+    for (RuleID R : Params.Table.getReduceRules(Head->State)) {
+      if (RID.hasValue())
         return false;
-      RID = A.getReduceRule();
+      RID = R;
     }
     if (!RID)
       return true; // no reductions available, but we've processed the head!
     const auto &Rule = Params.G.lookupRule(*RID);
+    if (!Params.Table.canFollow(Rule.Target, Lookahead))
+      return true; // reduction is not available
     const GSS::Node *Base = Head;
     TempSequence.resize_for_overwrite(Rule.Size);
     for (unsigned I = 0; I < Rule.Size; ++I) {

diff  --git a/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp b/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
index 1f700e53a92f2..3b35d232b8a44 100644
--- a/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
+++ b/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
@@ -10,6 +10,7 @@
 #include "clang-pseudo/grammar/Grammar.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -21,8 +22,6 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const LRTable::Action &A) {
   switch (A.kind()) {
   case LRTable::Action::Shift:
     return OS << llvm::formatv("shift state {0}", A.getShiftState());
-  case LRTable::Action::Reduce:
-    return OS << llvm::formatv("reduce by rule {0}", A.getReduceRule());
   case LRTable::Action::GoTo:
     return OS << llvm::formatv("go to state {0}", A.getGoToState());
   case LRTable::Action::Sentinel:
@@ -36,9 +35,11 @@ std::string LRTable::dumpStatistics() const {
 Statistics of the LR parsing table:
     number of states: {0}
     number of actions: {1}
-    size of the table (bytes): {2}
+    number of reduces: {2}
+    size of the table (bytes): {3}
 )",
-                       StateOffset.size() - 1, Actions.size(), bytes())
+                       StateOffset.size() - 1, Actions.size(), Reduces.size(),
+                       bytes())
       .str();
 }
 
@@ -52,19 +53,27 @@ std::string LRTable::dumpForTests(const Grammar &G) const {
       SymbolID TokID = tokenSymbol(static_cast<tok::TokenKind>(Terminal));
       for (auto A : find(S, TokID)) {
         if (A.kind() == LRTable::Action::Shift)
-          OS.indent(4) << llvm::formatv("'{0}': shift state {1}\n",
+          OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
                                         G.symbolName(TokID), A.getShiftState());
-        else if (A.kind() == LRTable::Action::Reduce)
-          OS.indent(4) << llvm::formatv("'{0}': reduce by rule {1} '{2}'\n",
-                                        G.symbolName(TokID), A.getReduceRule(),
-                                        G.dumpRule(A.getReduceRule()));
       }
     }
+    for (RuleID R : getReduceRules(S)) {
+      SymbolID Target = G.lookupRule(R).Target;
+      std::vector<llvm::StringRef> Terminals;
+      for (unsigned Terminal = 0; Terminal < NumTerminals; ++Terminal) {
+        SymbolID TokID = tokenSymbol(static_cast<tok::TokenKind>(Terminal));
+        if (canFollow(Target, TokID))
+          Terminals.push_back(G.symbolName(TokID));
+      }
+      OS.indent(4) << llvm::formatv("{0}: reduce by rule {1} '{2}'\n",
+                                    llvm::join(Terminals, " "), R,
+                                    G.dumpRule(R));
+    }
     for (SymbolID NontermID = 0; NontermID < G.table().Nonterminals.size();
          ++NontermID) {
       if (find(S, NontermID).empty())
         continue;
-      OS.indent(4) << llvm::formatv("'{0}': go to state {1}\n",
+      OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
                                     G.symbolName(NontermID),
                                     getGoToState(S, NontermID));
     }
@@ -77,18 +86,12 @@ LRTable::getShiftState(StateID State, SymbolID Terminal) const {
   // FIXME: we spend a significant amount of time on misses here.
   // We could consider storing a std::bitset for a cheaper test?
   assert(pseudo::isToken(Terminal) && "expected terminal symbol!");
-  for (const auto &Result : getActions(State, Terminal))
+  for (const auto &Result : find(State, Terminal))
     if (Result.kind() == Action::Shift)
       return Result.getShiftState(); // unique: no shift/shift conflicts.
   return llvm::None;
 }
 
-llvm::ArrayRef<LRTable::Action> LRTable::getActions(StateID State,
-                                                    SymbolID Terminal) const {
-  assert(pseudo::isToken(Terminal) && "expect terminal symbol!");
-  return find(State, Terminal);
-}
-
 LRTable::StateID LRTable::getGoToState(StateID State,
                                        SymbolID Nonterminal) const {
   assert(pseudo::isNonterminal(Nonterminal) && "expected nonterminal symbol!");

diff  --git a/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp b/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
index 7d112b8cebfec..59ea4ce5e3276 100644
--- a/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
+++ b/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
@@ -10,6 +10,7 @@
 #include "clang-pseudo/grammar/LRGraph.h"
 #include "clang-pseudo/grammar/LRTable.h"
 #include "clang/Basic/TokenKinds.h"
+#include "llvm/ADT/SmallSet.h"
 #include <cstdint>
 
 namespace llvm {
@@ -38,13 +39,13 @@ template <> struct DenseMapInfo<clang::pseudo::LRTable::Entry> {
 namespace clang {
 namespace pseudo {
 
-class LRTable::Builder {
-public:
-  Builder(llvm::ArrayRef<std::pair<SymbolID, StateID>> StartStates)
-      : StartStates(StartStates) {}
+struct LRTable::Builder {
+  std::vector<std::pair<SymbolID, StateID>> StartStates;
+  llvm::DenseSet<Entry> Entries;
+  llvm::DenseMap<StateID, llvm::SmallSet<RuleID, 4>> Reduces;
+  std::vector<llvm::DenseSet<SymbolID>> FollowSets;
 
-  bool insert(Entry E) { return Entries.insert(std::move(E)).second; }
-  LRTable build(const GrammarTable &GT, unsigned NumStates) && {
+  LRTable build(unsigned NumStates) && {
     // E.g. given the following parsing table with 3 states and 3 terminals:
     //
     //            a    b     c
@@ -86,16 +87,34 @@ class LRTable::Builder {
         ++SortedIndex;
     }
     Table.StartStates = std::move(StartStates);
+
+    // Compile the follow sets into a bitmap.
+    Table.FollowSets.resize(tok::NUM_TOKENS * FollowSets.size());
+    for (SymbolID NT = 0; NT < FollowSets.size(); ++NT)
+      for (SymbolID Follow : FollowSets[NT])
+        Table.FollowSets.set(NT * tok::NUM_TOKENS + symbolToToken(Follow));
+
+    // Store the reduce actions in a vector partitioned by state.
+    Table.ReduceOffset.reserve(NumStates + 1);
+    std::vector<RuleID> StateRules;
+    for (StateID S = 0; S < NumStates; ++S) {
+      Table.ReduceOffset.push_back(Table.Reduces.size());
+      auto It = Reduces.find(S);
+      if (It == Reduces.end())
+        continue;
+      Table.Reduces.insert(Table.Reduces.end(), It->second.begin(),
+                           It->second.end());
+      std::sort(Table.Reduces.begin() + Table.ReduceOffset.back(),
+                Table.Reduces.end());
+    }
+    Table.ReduceOffset.push_back(Table.Reduces.size());
+
     return Table;
   }
-
-private:
-  llvm::DenseSet<Entry> Entries;
-  std::vector<std::pair<SymbolID, StateID>> StartStates;
 };
 
-LRTable LRTable::buildForTests(const GrammarTable &GT,
-                               llvm::ArrayRef<Entry> Entries) {
+LRTable LRTable::buildForTests(const Grammar &G, llvm::ArrayRef<Entry> Entries,
+                               llvm::ArrayRef<ReduceEntry> Reduces) {
   StateID MaxState = 0;
   for (const auto &Entry : Entries) {
     MaxState = std::max(MaxState, Entry.State);
@@ -104,22 +123,26 @@ LRTable LRTable::buildForTests(const GrammarTable &GT,
     if (Entry.Act.kind() == LRTable::Action::GoTo)
       MaxState = std::max(MaxState, Entry.Act.getGoToState());
   }
-  Builder Build({});
-  for (const Entry &E : Entries)
-    Build.insert(E);
-  return std::move(Build).build(GT, /*NumStates=*/MaxState + 1);
+  Builder Build;
+  Build.Entries.insert(Entries.begin(), Entries.end());
+  for (const ReduceEntry &E : Reduces)
+    Build.Reduces[E.State].insert(E.Rule);
+  Build.FollowSets = followSets(G);
+  return std::move(Build).build(/*NumStates=*/MaxState + 1);
 }
 
 LRTable LRTable::buildSLR(const Grammar &G) {
   auto Graph = LRGraph::buildLR0(G);
-  Builder Build(Graph.startStates());
+  Builder Build;
+  Build.StartStates = Graph.startStates();
   for (const auto &T : Graph.edges()) {
     Action Act = isToken(T.Label) ? Action::shift(T.Dst) : Action::goTo(T.Dst);
-    Build.insert({T.Src, T.Label, Act});
+    Build.Entries.insert({T.Src, T.Label, Act});
   }
+  Build.FollowSets = followSets(G);
   assert(Graph.states().size() <= (1 << StateBits) &&
          "Graph states execceds the maximum limit!");
-  auto FollowSets = followSets(G);
+  // Add reduce actions.
   for (StateID SID = 0; SID < Graph.states().size(); ++SID) {
     for (const Item &I : Graph.states()[SID].Items) {
       // If we've just parsed the start symbol, this means we successfully parse
@@ -127,17 +150,13 @@ LRTable LRTable::buildSLR(const Grammar &G) {
       // LRTable (the GLR parser handles it specifically).
       if (G.lookupRule(I.rule()).Target == G.underscore() && !I.hasNext())
         continue;
-      if (!I.hasNext()) {
+      if (!I.hasNext())
         // If we've reached the end of a rule A := ..., then we can reduce if
         // the next token is in the follow set of A.
-        for (SymbolID Follow : FollowSets[G.lookupRule(I.rule()).Target]) {
-          assert(isToken(Follow));
-          Build.insert({SID, Follow, Action::reduce(I.rule())});
-        }
-      }
+        Build.Reduces[SID].insert(I.rule());
     }
   }
-  return std::move(Build).build(G.table(), Graph.states().size());
+  return std::move(Build).build(Graph.states().size());
 }
 
 } // namespace pseudo

diff  --git a/clang-tools-extra/pseudo/test/lr-build-basic.test b/clang-tools-extra/pseudo/test/lr-build-basic.test
index 36a86ecd9f1f8..eba705623dac4 100644
--- a/clang-tools-extra/pseudo/test/lr-build-basic.test
+++ b/clang-tools-extra/pseudo/test/lr-build-basic.test
@@ -18,11 +18,11 @@ id := IDENTIFIER
 # RUN: clang-pseudo -grammar %s -print-table | FileCheck %s --check-prefix=TABLE
 #      TABLE: LRTable:
 # TABLE-NEXT: State 0
-# TABLE-NEXT:     'IDENTIFIER': shift state 3
-# TABLE-NEXT:     'expr': go to state 1
-# TABLE-NEXT:     'id': go to state 2
+# TABLE-NEXT:     IDENTIFIER: shift state 3
+# TABLE-NEXT:     expr: go to state 1
+# TABLE-NEXT:     id: go to state 2
 # TABLE-NEXT: State 1
 # TABLE-NEXT: State 2
-# TABLE-NEXT:     'EOF': reduce by rule 1 'expr := id'
+# TABLE-NEXT:     EOF: reduce by rule 1 'expr := id'
 # TABLE-NEXT: State 3
-# TABLE-NEXT:     'EOF': reduce by rule 0 'id := IDENTIFIER'
+# TABLE-NEXT:     EOF: reduce by rule 0 'id := IDENTIFIER'

diff  --git a/clang-tools-extra/pseudo/test/lr-build-conflicts.test b/clang-tools-extra/pseudo/test/lr-build-conflicts.test
index 916589572ae48..e5149b865fd00 100644
--- a/clang-tools-extra/pseudo/test/lr-build-conflicts.test
+++ b/clang-tools-extra/pseudo/test/lr-build-conflicts.test
@@ -30,17 +30,15 @@ expr := IDENTIFIER
 # RUN: clang-pseudo -grammar %s -print-table | FileCheck %s --check-prefix=TABLE
 #      TABLE: LRTable:
 # TABLE-NEXT: State 0
-# TABLE-NEXT:     'IDENTIFIER': shift state 2
-# TABLE-NEXT:     'expr': go to state 1
+# TABLE-NEXT:     IDENTIFIER: shift state 2
+# TABLE-NEXT:     expr: go to state 1
 # TABLE-NEXT: State 1
-# TABLE-NEXT:     '-': shift state 3
+# TABLE-NEXT:     -: shift state 3
 # TABLE-NEXT: State 2
-# TABLE-NEXT:     'EOF': reduce by rule 1 'expr := IDENTIFIER'
-# TABLE-NEXT:     '-': reduce by rule 1 'expr := IDENTIFIER'
+# TABLE-NEXT:     EOF -: reduce by rule 1 'expr := IDENTIFIER'
 # TABLE-NEXT: State 3
-# TABLE-NEXT:     'IDENTIFIER': shift state 2
-# TABLE-NEXT:     'expr': go to state 4
+# TABLE-NEXT:     IDENTIFIER: shift state 2
+# TABLE-NEXT:     expr: go to state 4
 # TABLE-NEXT: State 4
-# TABLE-NEXT:     'EOF': reduce by rule 0 'expr := expr - expr'
-# TABLE-NEXT:     '-': shift state 3
-# TABLE-NEXT:     '-': reduce by rule 0 'expr := expr - expr'
+# TABLE-NEXT:     -: shift state 3
+# TABLE-NEXT:     EOF -: reduce by rule 0 'expr := expr - expr'

diff  --git a/clang-tools-extra/pseudo/unittests/GLRTest.cpp b/clang-tools-extra/pseudo/unittests/GLRTest.cpp
index 6e72f1049878e..42e2f0ad66945 100644
--- a/clang-tools-extra/pseudo/unittests/GLRTest.cpp
+++ b/clang-tools-extra/pseudo/unittests/GLRTest.cpp
@@ -31,6 +31,7 @@ namespace {
 
 using Action = LRTable::Action;
 using testing::AllOf;
+using testing::ElementsAre;
 using testing::UnorderedElementsAre;
 
 MATCHER_P(state, StateID, "") { return arg->State == StateID; }
@@ -112,11 +113,13 @@ TEST_F(GLRTest, ShiftMergingHeads) {
 
   buildGrammar({}, {}); // Create a fake empty grammar.
   LRTable T =
-      LRTable::buildForTests(G->table(), /*Entries=*/{
+      LRTable::buildForTests(*G, /*Entries=*/
+                             {
                                  {1, tokenSymbol(tok::semi), Action::shift(4)},
                                  {2, tokenSymbol(tok::semi), Action::shift(4)},
                                  {3, tokenSymbol(tok::semi), Action::shift(5)},
-                             });
+                             },
+                             {});
 
   ForestNode &SemiTerminal = Arena.createTerminal(tok::semi, 0);
   std::vector<const GSS::Node *> NewHeads;
@@ -142,14 +145,15 @@ TEST_F(GLRTest, ReduceConflictsSplitting) {
                {"class-name := IDENTIFIER", "enum-name := IDENTIFIER"});
 
   LRTable Table = LRTable::buildForTests(
-      G->table(), {
-                      {/*State=*/0, id("class-name"), Action::goTo(2)},
-                      {/*State=*/0, id("enum-name"), Action::goTo(3)},
-                      {/*State=*/1, tokenSymbol(tok::l_brace),
-                       Action::reduce(ruleFor("class-name"))},
-                      {/*State=*/1, tokenSymbol(tok::l_brace),
-                       Action::reduce(ruleFor("enum-name"))},
-                  });
+      *G,
+      {
+          {/*State=*/0, id("class-name"), Action::goTo(2)},
+          {/*State=*/0, id("enum-name"), Action::goTo(3)},
+      },
+      {
+          {/*State=*/1, ruleFor("class-name")},
+          {/*State=*/1, ruleFor("enum-name")},
+      });
 
   const auto *GSSNode0 =
       GSStack.addNode(/*State=*/0, /*ForestNode=*/nullptr, /*Parents=*/{});
@@ -157,7 +161,7 @@ TEST_F(GLRTest, ReduceConflictsSplitting) {
       GSStack.addNode(1, &Arena.createTerminal(tok::identifier, 0), {GSSNode0});
 
   std::vector<const GSS::Node *> Heads = {GSSNode1};
-  glrReduce(Heads, tokenSymbol(tok::l_brace), {*G, Table, Arena, GSStack});
+  glrReduce(Heads, tokenSymbol(tok::eof), {*G, Table, Arena, GSStack});
   EXPECT_THAT(Heads, UnorderedElementsAre(
                          GSSNode1,
                          AllOf(state(2), parsedSymbolID(id("class-name")),
@@ -189,15 +193,16 @@ TEST_F(GLRTest, ReduceSplittingDueToMultipleBases) {
       /*Parents=*/{GSSNode2, GSSNode3});
 
   LRTable Table = LRTable::buildForTests(
-      G->table(),
+      *G,
       {
           {/*State=*/2, id("ptr-operator"), Action::goTo(/*NextState=*/5)},
           {/*State=*/3, id("ptr-operator"), Action::goTo(/*NextState=*/6)},
-          {/*State=*/4, tokenSymbol(tok::identifier),
-           Action::reduce(ruleFor("ptr-operator"))},
+      },
+      {
+          {/*State=*/4, ruleFor("ptr-operator")},
       });
   std::vector<const GSS::Node *> Heads = {GSSNode4};
-  glrReduce(Heads, tokenSymbol(tok::identifier), {*G, Table, Arena, GSStack});
+  glrReduce(Heads, tokenSymbol(tok::eof), {*G, Table, Arena, GSStack});
 
   EXPECT_THAT(Heads, UnorderedElementsAre(
                          GSSNode4,
@@ -242,17 +247,17 @@ TEST_F(GLRTest, ReduceJoiningWithMultipleBases) {
 
   // FIXME: figure out a way to get rid of the hard-coded reduce RuleID!
   LRTable Table = LRTable::buildForTests(
-      G->table(),
+      *G,
       {
           {/*State=*/1, id("type-name"), Action::goTo(/*NextState=*/5)},
           {/*State=*/2, id("type-name"), Action::goTo(/*NextState=*/5)},
-          {/*State=*/3, tokenSymbol(tok::l_paren),
-           Action::reduce(/* type-name := class-name */ 0)},
-          {/*State=*/4, tokenSymbol(tok::l_paren),
-           Action::reduce(/* type-name := enum-name */ 1)},
+      },
+      {
+          {/*State=*/3, /* type-name := class-name */ 0},
+          {/*State=*/4, /* type-name := enum-name */ 1},
       });
   std::vector<const GSS::Node *> Heads = {GSSNode3, GSSNode4};
-  glrReduce(Heads, tokenSymbol(tok::l_paren), {*G, Table, Arena, GSStack});
+  glrReduce(Heads, tokenSymbol(tok::eof), {*G, Table, Arena, GSStack});
 
   // Verify that the stack heads are joint at state 5 after reduces.
   EXPECT_THAT(Heads, UnorderedElementsAre(GSSNode3, GSSNode4,
@@ -299,16 +304,17 @@ TEST_F(GLRTest, ReduceJoiningWithSameBase) {
                       /*Parents=*/{GSSNode2});
 
   // FIXME: figure out a way to get rid of the hard-coded reduce RuleID!
-  LRTable Table = LRTable::buildForTests(
-      G->table(), {
-                      {/*State=*/0, id("pointer"), Action::goTo(5)},
-                      {3, tokenSymbol(tok::l_paren),
-                       Action::reduce(/* pointer := class-name */ 0)},
-                      {4, tokenSymbol(tok::l_paren),
-                       Action::reduce(/* pointer := enum-name */ 1)},
-                  });
+  LRTable Table =
+      LRTable::buildForTests(*G,
+                             {
+                                 {/*State=*/0, id("pointer"), Action::goTo(5)},
+                             },
+                             {
+                                 {3, /* pointer := class-name */ 0},
+                                 {4, /* pointer := enum-name */ 1},
+                             });
   std::vector<const GSS::Node *> Heads = {GSSNode3, GSSNode4};
-  glrReduce(Heads, tokenSymbol(tok::l_paren), {*G, Table, Arena, GSStack});
+  glrReduce(Heads, tokenSymbol(tok::eof), {*G, Table, Arena, GSStack});
 
   EXPECT_THAT(
       Heads, UnorderedElementsAre(GSSNode3, GSSNode4,
@@ -325,6 +331,38 @@ TEST_F(GLRTest, ReduceJoiningWithSameBase) {
             "[  1, end)   └─* := tok[1]\n");
 }
 
+TEST_F(GLRTest, ReduceLookahead) {
+  // A term can be followed by +, but not by -.
+  buildGrammar({"sum", "term"}, {"expr := term + term", "term := IDENTIFIER"});
+  LRTable Table =
+      LRTable::buildForTests(*G,
+                             {
+                                 {/*State=*/0, id("term"), Action::goTo(2)},
+                             },
+                             {
+                                 {/*State=*/1, 0},
+                             });
+
+  auto *Identifier = &Arena.createTerminal(tok::identifier, /*Start=*/0);
+
+  const auto *Root =
+      GSStack.addNode(/*State=*/0, /*ForestNode=*/nullptr, /*Parents=*/{});
+  const auto *GSSNode1 =
+      GSStack.addNode(/*State=*/1, /*ForestNode=*/Identifier, {Root});
+
+  // When the lookahead is +, reduce is performed.
+  std::vector<const GSS::Node *> Heads = {GSSNode1};
+  glrReduce(Heads, tokenSymbol(tok::plus), {*G, Table, Arena, GSStack});
+  EXPECT_THAT(Heads,
+              ElementsAre(GSSNode1, AllOf(state(2), parsedSymbolID(id("term")),
+                                          parents(Root))));
+
+  // When the lookahead is -, reduce is not performed.
+  Heads = {GSSNode1};
+  glrReduce(Heads, tokenSymbol(tok::minus), {*G, Table, Arena, GSStack});
+  EXPECT_THAT(Heads, ElementsAre(GSSNode1));
+}
+
 TEST_F(GLRTest, PerfectForestNodeSharing) {
   // Run the GLR on a simple grammar and test that we build exactly one forest
   // node per (SymbolID, token range).

diff  --git a/clang-tools-extra/pseudo/unittests/LRTableTest.cpp b/clang-tools-extra/pseudo/unittests/LRTableTest.cpp
index f4ba9be49db1f..e44be1af01547 100644
--- a/clang-tools-extra/pseudo/unittests/LRTableTest.cpp
+++ b/clang-tools-extra/pseudo/unittests/LRTableTest.cpp
@@ -9,6 +9,7 @@
 #include "clang-pseudo/grammar/LRTable.h"
 #include "clang-pseudo/grammar/Grammar.h"
 #include "clang/Basic/TokenKinds.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <vector>
@@ -17,36 +18,59 @@ namespace clang {
 namespace pseudo {
 namespace {
 
-using testing::IsEmpty;
-using testing::UnorderedElementsAre;
+using llvm::ValueIs;
+using testing::ElementsAre;
 using Action = LRTable::Action;
 
 TEST(LRTable, Builder) {
-  GrammarTable GTable;
-
-  //           eof   semi  ...
-  // +-------+----+-------+---
-  // |state0 |    | s0,r0 |...
-  // |state1 | acc|       |...
-  // |state2 |    |  r1   |...
-  // +-------+----+-------+---
+  std::vector<std::string> GrammarDiags;
+  auto G = Grammar::parseBNF(R"bnf(
+    _ := expr            # rule 0
+    expr := term         # rule 1
+    expr := expr + term  # rule 2
+    term := IDENTIFIER   # rule 3
+  )bnf",
+                             GrammarDiags);
+  EXPECT_THAT(GrammarDiags, testing::IsEmpty());
+
+  SymbolID Term = *G->findNonterminal("term");
+  SymbolID Eof = tokenSymbol(tok::eof);
+  SymbolID Identifier = tokenSymbol(tok::identifier);
+  SymbolID Plus = tokenSymbol(tok::plus);
+
+  //           eof  IDENT   term
+  // +-------+----+-------+------+
+  // |state0 |    | s0    |      |
+  // |state1 |    |       | g3   |
+  // |state2 |    |       |      |
+  // +-------+----+-------+------+-------
   std::vector<LRTable::Entry> Entries = {
-      {/* State */ 0, tokenSymbol(tok::semi), Action::shift(0)},
-      {/* State */ 0, tokenSymbol(tok::semi), Action::reduce(0)},
-      {/* State */ 1, tokenSymbol(tok::eof), Action::reduce(2)},
-      {/* State */ 2, tokenSymbol(tok::semi), Action::reduce(1)}};
-  GrammarTable GT;
-  LRTable T = LRTable::buildForTests(GT, Entries);
-  EXPECT_THAT(T.find(0, tokenSymbol(tok::eof)), IsEmpty());
-  EXPECT_THAT(T.find(0, tokenSymbol(tok::semi)),
-              UnorderedElementsAre(Action::shift(0), Action::reduce(0)));
-  EXPECT_THAT(T.find(1, tokenSymbol(tok::eof)),
-              UnorderedElementsAre(Action::reduce(2)));
-  EXPECT_THAT(T.find(1, tokenSymbol(tok::semi)), IsEmpty());
-  EXPECT_THAT(T.find(2, tokenSymbol(tok::semi)),
-              UnorderedElementsAre(Action::reduce(1)));
+      {/* State */ 0, Identifier, Action::shift(0)},
+      {/* State */ 1, Term, Action::goTo(3)},
+  };
+  std::vector<LRTable::ReduceEntry> ReduceEntries = {
+      {/*State=*/0, /*Rule=*/0},
+      {/*State=*/1, /*Rule=*/2},
+      {/*State=*/2, /*Rule=*/1},
+  };
+  LRTable T = LRTable::buildForTests(*G, Entries, ReduceEntries);
+  EXPECT_EQ(T.getShiftState(0, Eof), llvm::None);
+  EXPECT_THAT(T.getShiftState(0, Identifier), ValueIs(0));
+  EXPECT_THAT(T.getReduceRules(0), ElementsAre(0));
+
+  EXPECT_EQ(T.getShiftState(1, Eof), llvm::None);
+  EXPECT_EQ(T.getShiftState(1, Identifier), llvm::None);
+  EXPECT_EQ(T.getGoToState(1, Term), 3);
+  EXPECT_THAT(T.getReduceRules(1), ElementsAre(2));
+
   // Verify the behaivor for other non-available-actions terminals.
-  EXPECT_THAT(T.find(2, tokenSymbol(tok::kw_int)), IsEmpty());
+  SymbolID Int = tokenSymbol(tok::kw_int);
+  EXPECT_EQ(T.getShiftState(2, Int), llvm::None);
+
+  // Check follow sets.
+  EXPECT_TRUE(T.canFollow(Term, Plus));
+  EXPECT_TRUE(T.canFollow(Term, Eof));
+  EXPECT_FALSE(T.canFollow(Term, Int));
 }
 
 } // namespace


        


More information about the cfe-commits mailing list