This revision was landed with ongoing or failed builds.
This revision was automatically updated to reflect the committed changes.
sammccall marked 3 inline comments as done.
Closed by commit rGb37dafd5dc83: [pseudo] Store shift and goto actions in a 
compact structure with faster lookup. (authored by sammccall).

Changed prior to commit:
  https://reviews.llvm.org/D128485?vs=439589&id=442123#toc

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D128485/new/

https://reviews.llvm.org/D128485

Files:
  clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
  clang-tools-extra/pseudo/lib/GLR.cpp
  clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
  clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
  clang-tools-extra/pseudo/unittests/LRTableTest.cpp

Index: clang-tools-extra/pseudo/unittests/LRTableTest.cpp
===================================================================
--- clang-tools-extra/pseudo/unittests/LRTableTest.cpp
+++ clang-tools-extra/pseudo/unittests/LRTableTest.cpp
@@ -60,7 +60,7 @@
 
   EXPECT_EQ(T.getShiftState(1, Eof), llvm::None);
   EXPECT_EQ(T.getShiftState(1, Identifier), llvm::None);
-  EXPECT_EQ(T.getGoToState(1, Term), 3);
+  EXPECT_THAT(T.getGoToState(1, Term), ValueIs(3));
   EXPECT_THAT(T.getReduceRules(1), ElementsAre(2));
 
   // Verify the behaivor for other non-available-actions terminals.
Index: clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
+++ clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
@@ -45,49 +45,24 @@
   llvm::DenseMap<StateID, llvm::SmallSet<RuleID, 4>> Reduces;
   std::vector<llvm::DenseSet<SymbolID>> FollowSets;
 
-  LRTable build(unsigned NumStates) && {
-    // E.g. given the following parsing table with 3 states and 3 terminals:
-    //
-    //            a    b     c
-    // +-------+----+-------+-+
-    // |state0 |    | s0,r0 | |
-    // |state1 | acc|       | |
-    // |state2 |    |  r1   | |
-    // +-------+----+-------+-+
-    //
-    // The final LRTable:
-    //  - StateOffset: [s0] = 0, [s1] = 2, [s2] = 3, [sentinel] = 4
-    //  - Symbols:     [ b,   b,   a,  b]
-    //    Actions:     [ s0, r0, acc, r1]
-    //                   ~~~~~~ range for state 0
-    //                           ~~~~ range for state 1
-    //                                ~~ range for state 2
-    // First step, we sort all entries by (State, Symbol, Action).
-    std::vector<Entry> Sorted(Entries.begin(), Entries.end());
-    llvm::sort(Sorted, [](const Entry &L, const Entry &R) {
-      return std::forward_as_tuple(L.State, L.Symbol, L.Act.opaque()) <
-             std::forward_as_tuple(R.State, R.Symbol, R.Act.opaque());
-    });
-
+  LRTable build(unsigned NumStates, unsigned NumNonterminals) && {
     LRTable Table;
-    Table.Actions.reserve(Sorted.size());
-    Table.Symbols.reserve(Sorted.size());
-    // We are good to finalize the States and Actions.
-    for (const auto &E : Sorted) {
-      Table.Actions.push_back(E.Act);
-      Table.Symbols.push_back(E.Symbol);
-    }
-    // Initialize the terminal and nonterminal offset, all ranges are empty by
-    // default.
-    Table.StateOffset = std::vector<uint32_t>(NumStates + 1, 0);
-    size_t SortedIndex = 0;
-    for (StateID State = 0; State < Table.StateOffset.size(); ++State) {
-      Table.StateOffset[State] = SortedIndex;
-      while (SortedIndex < Sorted.size() && Sorted[SortedIndex].State == State)
-        ++SortedIndex;
-    }
     Table.StartStates = std::move(StartStates);
 
+    // Compile the goto and shift actions into transition tables.
+    llvm::DenseMap<unsigned, SymbolID> Gotos;
+    llvm::DenseMap<unsigned, SymbolID> Shifts;
+    for (const auto &E : Entries) {
+      if (E.Act.kind() == Action::Shift)
+        Shifts.try_emplace(shiftIndex(E.State, E.Symbol, NumStates),
+                           E.Act.getShiftState());
+      else if (E.Act.kind() == Action::GoTo)
+        Gotos.try_emplace(gotoIndex(E.State, E.Symbol, NumStates),
+                          E.Act.getGoToState());
+    }
+    Table.Shifts = TransitionTable(Shifts, NumStates * NumTerminals);
+    Table.Gotos = TransitionTable(Gotos, NumStates * NumNonterminals);
+
     // Compile the follow sets into a bitmap.
     Table.FollowSets.resize(tok::NUM_TOKENS * FollowSets.size());
     for (SymbolID NT = 0; NT < FollowSets.size(); ++NT)
@@ -128,7 +103,8 @@
   for (const ReduceEntry &E : Reduces)
     Build.Reduces[E.State].insert(E.Rule);
   Build.FollowSets = followSets(G);
-  return std::move(Build).build(/*NumStates=*/MaxState + 1);
+  return std::move(Build).build(/*NumStates=*/MaxState + 1,
+                                G.table().Nonterminals.size());
 }
 
 LRTable LRTable::buildSLR(const Grammar &G) {
@@ -156,7 +132,8 @@
         Build.Reduces[SID].insert(I.rule());
     }
   }
-  return std::move(Build).build(Graph.states().size());
+  return std::move(Build).build(Graph.states().size(),
+                                G.table().Nonterminals.size());
 }
 
 } // namespace pseudo
Index: clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
+++ clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
@@ -34,11 +34,10 @@
   return llvm::formatv(R"(
 Statistics of the LR parsing table:
     number of states: {0}
-    number of actions: {1}
-    number of reduces: {2}
-    size of the table (bytes): {3}
+    number of actions: shift={1} goto={2} reduce={3}
+    size of the table (bytes): {4}
 )",
-                       StateOffset.size() - 1, Actions.size(), Reduces.size(),
+                       numStates(), Shifts.size(), Gotos.size(), Reduces.size(),
                        bytes())
       .str();
 }
@@ -47,15 +46,13 @@
   std::string Result;
   llvm::raw_string_ostream OS(Result);
   OS << "LRTable:\n";
-  for (StateID S = 0; S < StateOffset.size() - 1; ++S) {
+  for (StateID S = 0; S < numStates(); ++S) {
     OS << llvm::formatv("State {0}\n", S);
     for (uint16_t Terminal = 0; Terminal < NumTerminals; ++Terminal) {
       SymbolID TokID = tokenSymbol(static_cast<tok::TokenKind>(Terminal));
-      for (auto A : find(S, TokID)) {
-        if (A.kind() == LRTable::Action::Shift)
-          OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
-                                        G.symbolName(TokID), A.getShiftState());
-      }
+      if (auto SS = getShiftState(S, TokID))
+        OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
+                                      G.symbolName(TokID), SS);
     }
     for (RuleID R : getReduceRules(S)) {
       SymbolID Target = G.lookupRule(R).Target;
@@ -71,55 +68,15 @@
     }
     for (SymbolID NontermID = 0; NontermID < G.table().Nonterminals.size();
          ++NontermID) {
-      if (find(S, NontermID).empty())
-        continue;
-      OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
-                                    G.symbolName(NontermID),
-                                    getGoToState(S, NontermID));
+      if (auto GS = getGoToState(S, NontermID)) {
+        OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
+                                      G.symbolName(NontermID), *GS);
+      }
     }
   }
   return OS.str();
 }
 
-llvm::Optional<LRTable::StateID>
-LRTable::getShiftState(StateID State, SymbolID Terminal) const {
-  // FIXME: we spend a significant amount of time on misses here.
-  // We could consider storing a std::bitset for a cheaper test?
-  assert(pseudo::isToken(Terminal) && "expected terminal symbol!");
-  for (const auto &Result : find(State, Terminal))
-    if (Result.kind() == Action::Shift)
-      return Result.getShiftState(); // unique: no shift/shift conflicts.
-  return llvm::None;
-}
-
-LRTable::StateID LRTable::getGoToState(StateID State,
-                                       SymbolID Nonterminal) const {
-  assert(pseudo::isNonterminal(Nonterminal) && "expected nonterminal symbol!");
-  auto Result = find(State, Nonterminal);
-  assert(Result.size() == 1 && Result.front().kind() == Action::GoTo);
-  return Result.front().getGoToState();
-}
-
-llvm::ArrayRef<LRTable::Action> LRTable::find(StateID Src, SymbolID ID) const {
-  assert(Src + 1u < StateOffset.size());
-  std::pair<size_t, size_t> Range =
-      std::make_pair(StateOffset[Src], StateOffset[Src + 1]);
-  auto SymbolRange = llvm::makeArrayRef(Symbols.data() + Range.first,
-                                        Symbols.data() + Range.second);
-
-  assert(llvm::is_sorted(SymbolRange) &&
-         "subrange of the Symbols should be sorted!");
-  const LRTable::StateID *Start =
-      llvm::partition_point(SymbolRange, [&ID](SymbolID S) { return S < ID; });
-  if (Start == SymbolRange.end())
-    return {};
-  const LRTable::StateID *End = Start;
-  while (End != SymbolRange.end() && *End == ID)
-    ++End;
-  return llvm::makeArrayRef(&Actions[Start - Symbols.data()],
-                            /*length=*/End - Start);
-}
-
 LRTable::StateID LRTable::getStartState(SymbolID Target) const {
   assert(llvm::is_sorted(StartStates) && "StartStates must be sorted!");
   auto It = llvm::partition_point(
Index: clang-tools-extra/pseudo/lib/GLR.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/GLR.cpp
+++ clang-tools-extra/pseudo/lib/GLR.cpp
@@ -318,9 +318,11 @@
     do {
       const PushSpec &Push = Sequences.top().second;
       FamilySequences.emplace_back(Sequences.top().first.Rule, *Push.Seq);
-      for (const GSS::Node *Base : Push.LastPop->parents())
-        FamilyBases.emplace_back(
-            Params.Table.getGoToState(Base->State, F.Symbol), Base);
+      for (const GSS::Node *Base : Push.LastPop->parents()) {
+        auto NextState = Params.Table.getGoToState(Base->State, F.Symbol);
+        assert(NextState.hasValue() && "goto must succeed after reduce!");
+        FamilyBases.emplace_back(*NextState, Base);
+      }
 
       Sequences.pop();
     } while (!Sequences.empty() && Sequences.top().first == F);
@@ -393,8 +395,9 @@
     }
     const ForestNode *Parsed =
         &Params.Forest.createSequence(Rule.Target, *RID, TempSequence);
-    StateID NextState = Params.Table.getGoToState(Base->State, Rule.Target);
-    Heads->push_back(Params.GSStack.addNode(NextState, Parsed, {Base}));
+    auto NextState = Params.Table.getGoToState(Base->State, Rule.Target);
+    assert(NextState.hasValue() && "goto must succeed after reduce!");
+    Heads->push_back(Params.GSStack.addNode(*NextState, Parsed, {Base}));
     return true;
   }
 };
@@ -444,7 +447,8 @@
   }
   LLVM_DEBUG(llvm::dbgs() << llvm::formatv("Reached eof\n"));
 
-  StateID AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
+  auto AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
+  assert(AcceptState.hasValue() && "goto must succeed after start symbol!");
   const ForestNode *Result = nullptr;
   for (const auto *Head : Heads) {
     if (Head->State == AcceptState) {
Index: clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
===================================================================
--- clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
+++ clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
@@ -40,6 +40,8 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/Support/Capacity.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 #include <vector>
 
@@ -123,12 +125,18 @@
 
   // Returns the state after we reduce a nonterminal.
   // Expected to be called by LR parsers.
-  // REQUIRES: Nonterminal is valid here.
-  StateID getGoToState(StateID State, SymbolID Nonterminal) const;
+  // If the nonterminal is invalid here, returns None.
+  llvm::Optional<StateID> getGoToState(StateID State,
+                                       SymbolID Nonterminal) const {
+    return Gotos.get(gotoIndex(State, Nonterminal, numStates()));
+  }
   // Returns the state after we shift a terminal.
   // Expected to be called by LR parsers.
   // If the terminal is invalid here, returns None.
-  llvm::Optional<StateID> getShiftState(StateID State, SymbolID Terminal) const;
+  llvm::Optional<StateID> getShiftState(StateID State,
+                                        SymbolID Terminal) const {
+    return Shifts.get(shiftIndex(State, Terminal, numStates()));
+  }
 
   // Returns the possible reductions from a state.
   //
@@ -164,9 +172,7 @@
   StateID getStartState(SymbolID StartSymbol) const;
 
   size_t bytes() const {
-    return sizeof(*this) + llvm::capacity_in_bytes(Actions) +
-           llvm::capacity_in_bytes(Symbols) +
-           llvm::capacity_in_bytes(StateOffset) +
+    return sizeof(*this) + Gotos.bytes() + Shifts.bytes() +
            llvm::capacity_in_bytes(Reduces) +
            llvm::capacity_in_bytes(ReduceOffset) +
            llvm::capacity_in_bytes(FollowSets);
@@ -194,22 +200,92 @@
                                llvm::ArrayRef<ReduceEntry>);
 
 private:
-  // Looks up actions stored in the generic table.
-  llvm::ArrayRef<Action> find(StateID State, SymbolID Symbol) const;
-
-  // Conceptually the LR table is a multimap from (State, SymbolID) => Action.
-  // Our physical representation is quite different for compactness.
-
-  // Index is StateID, value is the offset into Symbols/Actions
-  // where the entries for this state begin.
-  // Give a state id, the corresponding half-open range of Symbols/Actions is
-  // [StateOffset[id], StateOffset[id+1]).
-  std::vector<uint32_t> StateOffset;
-  // Parallel to Actions, the value is SymbolID (columns of the matrix).
-  // Grouped by the StateID, and only subranges are sorted.
-  std::vector<SymbolID> Symbols;
-  // A flat list of available actions, sorted by (State, SymbolID).
-  std::vector<Action> Actions;
+  unsigned numStates() const { return ReduceOffset.size() - 1; }
+
+  // A map from unsigned key => StateID, used to store actions.
+  // The keys form a dense range [0, NumKeys), but only some keys have a value.
+  //
+  // In practice, the keys encode (origin state, symbol) pairs, and the values
+  // are the state we should move to after seeing that symbol.
+  //
+  // We store one bit for presence/absence of the value for each key.
+  // At every 64th key, we store the offset into the table of values.
+  //   e.g. key 0x500 is checkpoint 0x500/64 = 20
+  //                     Checkpoints[20] = 34
+  //        get(0x500) = Values[34]                (assuming it has a value)
+  // To look up values in between, we count the set bits:
+  //        get(0x509) has a value if HasValue[20] & (1<<9)
+  //        #values between 0x500 and 0x509: popcnt(HasValue[20] & ((1<<9) - 1))
+  //        get(0x509) = Values[34 + popcnt(...)]
+  //
+  // Overall size is 1.25 bits/key + 16 bits/value.
+  // Lookup is constant time with a low factor (no hashing).
+  class TransitionTable {
+    using Word = uint64_t;
+    constexpr static unsigned WordBits = CHAR_BIT * sizeof(Word);
+
+    std::vector<StateID> Values;
+    std::vector<Word> HasValue;
+    std::vector<uint16_t> Checkpoints;
+
+  public:
+    TransitionTable() = default;
+    TransitionTable(const llvm::DenseMap<unsigned, StateID> &Entries,
+                    unsigned NumKeys) {
+      assert(
+          Entries.size() <
+              std::numeric_limits<decltype(Checkpoints)::value_type>::max() &&
+          "16 bits too small for value offsets!");
+      unsigned NumWords = (NumKeys + WordBits - 1) / WordBits;
+      HasValue.resize(NumWords, 0);
+      Checkpoints.reserve(NumWords);
+      Values.reserve(Entries.size());
+      for (unsigned I = 0; I < NumKeys; ++I) {
+        if ((I % WordBits) == 0)
+          Checkpoints.push_back(Values.size());
+        auto It = Entries.find(I);
+        if (It != Entries.end()) {
+          HasValue[I / WordBits] |= (Word(1) << (I % WordBits));
+          Values.push_back(It->second);
+        }
+      }
+    }
+
+    llvm::Optional<StateID> get(unsigned Key) const {
+      // Do we have a value for this key?
+      Word KeyMask = Word(1) << (Key % WordBits);
+      unsigned KeyWord = Key / WordBits;
+      if ((HasValue[KeyWord] & KeyMask) == 0)
+        return llvm::None;
+      // Count the number of values since the checkpoint.
+      Word BelowKeyMask = KeyMask - 1;
+      unsigned CountSinceCheckpoint =
+          llvm::countPopulation(HasValue[KeyWord] & BelowKeyMask);
+      // Find the value relative to the last checkpoint.
+      return Values[Checkpoints[KeyWord] + CountSinceCheckpoint];
+    }
+
+    unsigned size() const { return Values.size(); }
+
+    size_t bytes() const {
+      return llvm::capacity_in_bytes(HasValue) +
+             llvm::capacity_in_bytes(Values) +
+             llvm::capacity_in_bytes(Checkpoints);
+    }
+  };
+  // Shift and Goto tables are keyed by encoded (State, Symbol).
+  static unsigned shiftIndex(StateID State, SymbolID Terminal,
+                             unsigned NumStates) {
+    return NumStates * symbolToToken(Terminal) + State;
+  }
+  static unsigned gotoIndex(StateID State, SymbolID Nonterminal,
+                            unsigned NumStates) {
+    assert(isNonterminal(Nonterminal));
+    return NumStates * Nonterminal + State;
+  }
+  TransitionTable Shifts;
+  TransitionTable Gotos;
+
   // A sorted table, storing the start state for each target parsing symbol.
   std::vector<std::pair<SymbolID, StateID>> StartStates;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to