[llvm] [TableGen][DecoderEmitter] Rework table construction/emission (PR #155889)

Sergei Barannikov via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 13 17:50:50 PDT 2025


https://github.com/s-barannikov updated https://github.com/llvm/llvm-project/pull/155889

From d894c250844986d6a5015c3a0280921f68e48c62 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Sun, 14 Sep 2025 03:19:24 +0300
Subject: [PATCH] [TableGen][DecoderEmitter] Rework table construction/emission

---
 llvm/include/llvm/MC/MCDecoderOps.h    |   20 +-
 llvm/utils/TableGen/DecoderEmitter.cpp | 1122 +++++++++++++-----------
 2 files changed, 634 insertions(+), 508 deletions(-)

diff --git a/llvm/include/llvm/MC/MCDecoderOps.h b/llvm/include/llvm/MC/MCDecoderOps.h
index 790ff3eb4f333..4e06deb0eacee 100644
--- a/llvm/include/llvm/MC/MCDecoderOps.h
+++ b/llvm/include/llvm/MC/MCDecoderOps.h
@@ -13,19 +13,15 @@
 namespace llvm::MCD {
 
 // Disassembler state machine opcodes.
-// nts_t is either uint16_t or uint24_t based on whether large decoder table is
-// enabled.
 enum DecoderOps {
-  OPC_Scope = 1,         // OPC_Scope(nts_t NumToSkip)
-  OPC_ExtractField,      // OPC_ExtractField(uleb128 Start, uint8_t Len)
-  OPC_FilterValueOrSkip, // OPC_FilterValueOrSkip(uleb128 Val, nts_t NumToSkip)
-  OPC_FilterValue,       // OPC_FilterValue(uleb128 Val)
-  OPC_CheckField,        // OPC_CheckField(uleb128 Start, uint8_t Len,
-                         //                uleb128 Val)
-  OPC_CheckPredicate,    // OPC_CheckPredicate(uleb128 PIdx)
-  OPC_Decode,            // OPC_Decode(uleb128 Opcode, uleb128 DIdx)
-  OPC_TryDecode,         // OPC_TryDecode(uleb128 Opcode, uleb128 DIdx)
-  OPC_SoftFail,          // OPC_SoftFail(uleb128 PMask, uleb128 NMask)
+  OPC_Scope = 1,      // OPC_Scope(uleb128 Size)
+  OPC_SwitchField,    // OPC_SwitchField(uleb128 Start, uint8_t Len,
+                      //                 [uleb128 Val, uleb128 Size]...)
+  OPC_CheckField,     // OPC_CheckField(uleb128 Start, uint8_t Len, uleb128 Val)
+  OPC_CheckPredicate, // OPC_CheckPredicate(uleb128 PIdx)
+  OPC_Decode,         // OPC_Decode(uleb128 Opcode, uleb128 DIdx)
+  OPC_TryDecode,      // OPC_TryDecode(uleb128 Opcode, uleb128 DIdx)
+  OPC_SoftFail,       // OPC_SoftFail(uleb128 PMask, uleb128 NMask)
 };
 
 } // namespace llvm::MCD
diff --git a/llvm/utils/TableGen/DecoderEmitter.cpp b/llvm/utils/TableGen/DecoderEmitter.cpp
index a8a9036a1a7f4..c113673feea08 100644
--- a/llvm/utils/TableGen/DecoderEmitter.cpp
+++ b/llvm/utils/TableGen/DecoderEmitter.cpp
@@ -23,12 +23,10 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/MC/MCDecoderOps.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -81,12 +79,6 @@ static cl::opt<SuppressLevel> DecoderEmitterSuppressDuplicates(
             "significantly reducing Table Duplications")),
     cl::init(SUPPRESSION_DISABLE), cl::cat(DisassemblerEmitterCat));
 
-static cl::opt<bool> LargeTable(
-    "large-decoder-table",
-    cl::desc("Use large decoder table format. This uses 24 bits for offset\n"
-             "in the table instead of the default 16 bits."),
-    cl::init(false), cl::cat(DisassemblerEmitterCat));
-
 static cl::opt<bool> UseFnTableInDecodeToMCInst(
     "use-fn-table-in-decode-to-mcinst",
     cl::desc(
@@ -124,8 +116,6 @@ STATISTIC(NumInstructions, "Number of instructions considered");
 STATISTIC(NumEncodingsSupported, "Number of encodings supported");
 STATISTIC(NumEncodingsOmitted, "Number of encodings omitted");
 
-static unsigned getNumToSkipInBytes() { return LargeTable ? 3 : 2; }
-
 /// Similar to KnownBits::print(), but allows you to specify a character to use
 /// to print unknown bits.
 static void printKnownBits(raw_ostream &OS, const KnownBits &Bits,
@@ -275,70 +265,185 @@ class LessEncodingIDByWidth {
   }
 };
 
-typedef SmallSetVector<CachedHashString, 16> PredicateSet;
-typedef SmallSetVector<CachedHashString, 16> DecoderSet;
+class DecoderTreeNode {
+public:
+  virtual ~DecoderTreeNode() = default;
+
+  enum KindTy {
+    AnyOf,
+    AllOf,
+    CheckField,
+    SwitchField,
+    CheckPredicate,
+    SoftFail,
+    Decode,
+  };
+
+  KindTy getKind() const { return Kind; }
+
+protected:
+  explicit DecoderTreeNode(KindTy Kind) : Kind(Kind) {}
+
+private:
+  KindTy Kind;
+};
+
+class AllOfNode : public DecoderTreeNode {
+  SmallVector<std::unique_ptr<DecoderTreeNode>, 0> Children;
+
+  static const DecoderTreeNode *
+  mapElement(decltype(Children)::const_reference Element) {
+    return Element.get();
+  }
 
-class DecoderTable {
 public:
-  DecoderTable() { Data.reserve(16384); }
+  AllOfNode() : DecoderTreeNode(AllOf) {}
+
+  void addChild(std::unique_ptr<DecoderTreeNode> Child) {
+    Children.push_back(std::move(Child));
+  }
 
-  void clear() { Data.clear(); }
-  size_t size() const { return Data.size(); }
-  const uint8_t *data() const { return Data.data(); }
+  using child_iterator = mapped_iterator<decltype(Children)::const_iterator,
+                                         decltype(&mapElement)>;
 
-  using const_iterator = std::vector<uint8_t>::const_iterator;
-  const_iterator begin() const { return Data.begin(); }
-  const_iterator end() const { return Data.end(); }
+  child_iterator child_begin() const {
+    return child_iterator(Children.begin(), mapElement);
+  }
 
-  /// Inserts a state machine opcode into the table.
-  void insertOpcode(MCD::DecoderOps Opcode) { Data.push_back(Opcode); }
+  child_iterator child_end() const {
+    return child_iterator(Children.end(), mapElement);
+  }
 
-  /// Inserts a uint8 encoded value into the table.
-  void insertUInt8(unsigned Value) {
-    assert(isUInt<8>(Value));
-    Data.push_back(Value);
+  iterator_range<child_iterator> children() const {
+    return make_range(child_begin(), child_end());
   }
+};
 
-  /// Inserts a ULEB128 encoded value into the table.
-  void insertULEB128(uint64_t Value) {
-    // Encode and emit the value to filter against.
-    uint8_t Buffer[16];
-    unsigned Len = encodeULEB128(Value, Buffer);
-    Data.insert(Data.end(), Buffer, Buffer + Len);
+class AnyOfNode : public DecoderTreeNode {
+  SmallVector<std::unique_ptr<DecoderTreeNode>, 0> Children;
+
+  static const DecoderTreeNode *
+  mapElement(decltype(Children)::const_reference Element) {
+    return Element.get();
   }
 
-  // Insert space for `NumToSkip` and return the position
-  // in the table for patching.
-  size_t insertNumToSkip() {
-    size_t Size = Data.size();
-    Data.insert(Data.end(), getNumToSkipInBytes(), 0);
-    return Size;
+public:
+  AnyOfNode() : DecoderTreeNode(AnyOf) {}
+
+  void addChild(std::unique_ptr<DecoderTreeNode> N) {
+    Children.push_back(std::move(N));
   }
 
-  void patchNumToSkip(size_t FixupIdx, uint32_t DestIdx) {
-    // Calculate the distance from the byte following the fixup entry byte
-    // to the destination. The Target is calculated from after the
-    // `getNumToSkipInBytes()`-byte NumToSkip entry itself, so subtract
-    // `getNumToSkipInBytes()` from the displacement here to account for that.
-    assert(DestIdx >= FixupIdx + getNumToSkipInBytes() &&
-           "Expecting a forward jump in the decoding table");
-    uint32_t Delta = DestIdx - FixupIdx - getNumToSkipInBytes();
-    if (!isUIntN(8 * getNumToSkipInBytes(), Delta))
-      PrintFatalError(
-          "disassembler decoding table too large, try --large-decoder-table");
+  using child_iterator = mapped_iterator<decltype(Children)::const_iterator,
+                                         decltype(&mapElement)>;
 
-    Data[FixupIdx] = static_cast<uint8_t>(Delta);
-    Data[FixupIdx + 1] = static_cast<uint8_t>(Delta >> 8);
-    if (getNumToSkipInBytes() == 3)
-      Data[FixupIdx + 2] = static_cast<uint8_t>(Delta >> 16);
+  child_iterator child_begin() const {
+    return child_iterator(Children.begin(), mapElement);
   }
 
-private:
-  std::vector<uint8_t> Data;
+  child_iterator child_end() const {
+    return child_iterator(Children.end(), mapElement);
+  }
+
+  iterator_range<child_iterator> children() const {
+    return make_range(child_begin(), child_end());
+  }
+};
+
+class SwitchFieldNode : public DecoderTreeNode {
+  unsigned StartBit;
+  unsigned NumBits;
+  std::map<uint64_t, std::unique_ptr<DecoderTreeNode>> Cases;
+
+  static std::pair<uint64_t, const DecoderTreeNode *>
+  mapElement(decltype(Cases)::const_reference Element) {
+    return std::pair(Element.first, Element.second.get());
+  }
+
+public:
+  SwitchFieldNode(unsigned StartBit, unsigned NumBits)
+      : DecoderTreeNode(SwitchField), StartBit(StartBit), NumBits(NumBits) {}
+
+  void addCase(uint64_t Value, std::unique_ptr<DecoderTreeNode> N) {
+    Cases.try_emplace(Value, std::move(N));
+  }
+
+  unsigned getStartBit() const { return StartBit; }
+
+  unsigned getNumBits() const { return NumBits; }
+
+  using case_iterator =
+      mapped_iterator<decltype(Cases)::const_iterator, decltype(&mapElement)>;
+
+  case_iterator case_begin() const {
+    return case_iterator(Cases.begin(), mapElement);
+  }
+
+  case_iterator case_end() const {
+    return case_iterator(Cases.end(), mapElement);
+  }
+
+  iterator_range<case_iterator> cases() const {
+    return make_range(case_begin(), case_end());
+  }
+};
+
+class CheckFieldNode : public DecoderTreeNode {
+  unsigned StartBit;
+  unsigned NumBits;
+  uint64_t Value;
+
+public:
+  CheckFieldNode(unsigned StartBit, unsigned NumBits, uint64_t Value)
+      : DecoderTreeNode(CheckField), StartBit(StartBit), NumBits(NumBits),
+        Value(Value) {}
+
+  unsigned getStartBit() const { return StartBit; }
+
+  unsigned getNumBits() const { return NumBits; }
+
+  uint64_t getValue() const { return Value; }
+};
+
+class CheckPredicateNode : public DecoderTreeNode {
+  unsigned Index;
+
+public:
+  explicit CheckPredicateNode(unsigned Index)
+      : DecoderTreeNode(CheckPredicate), Index(Index) {}
+
+  unsigned getPredicateIndex() const { return Index; }
+};
+
+class SoftFailNode : public DecoderTreeNode {
+  uint64_t PositiveMask, NegativeMask;
+
+public:
+  explicit SoftFailNode(uint64_t PositiveMask, uint64_t NegativeMask)
+      : DecoderTreeNode(SoftFail), PositiveMask(PositiveMask),
+        NegativeMask(NegativeMask) {}
+
+  uint64_t getPositiveMask() const { return PositiveMask; }
+  uint64_t getNegativeMask() const { return NegativeMask; }
+};
+
+class DecodeNode : public DecoderTreeNode {
+  unsigned EncodingID;
+  unsigned Index;
+
+public:
+  DecodeNode(unsigned EncodingID, unsigned Index)
+      : DecoderTreeNode(Decode), EncodingID(EncodingID), Index(Index) {}
+
+  unsigned getEncodingID() const { return EncodingID; }
+
+  unsigned getDecoderIndex() const { return Index; }
 };
 
+typedef SmallSetVector<CachedHashString, 16> PredicateSet;
+typedef SmallSetVector<CachedHashString, 16> DecoderSet;
+
 struct DecoderTableInfo {
-  DecoderTable Table;
   PredicateSet Predicates;
   DecoderSet Decoders;
 };
@@ -361,11 +466,10 @@ class DecoderEmitter {
 
   const CodeGenTarget &getTarget() const { return Target; }
 
-  // Emit the decoder state machine table. Returns a mask of MCD decoder ops
-  // that were emitted.
-  unsigned emitTable(formatted_raw_ostream &OS, DecoderTable &Table,
-                     StringRef Namespace, unsigned HwModeID, unsigned BitWidth,
-                     ArrayRef<unsigned> EncodingIDs) const;
+  // Emit the decoder state machine table.
+  void emitDecoderTable(formatted_raw_ostream &OS, const DecoderTreeNode *Tree,
+                        StringRef Namespace, unsigned HwModeID,
+                        unsigned BitWidth) const;
   void emitInstrLenTable(formatted_raw_ostream &OS,
                          ArrayRef<unsigned> InstrLen) const;
   void emitPredicateFunction(formatted_raw_ostream &OS,
@@ -476,7 +580,7 @@ enum bitAttr_t {
 
 class FilterChooser {
   // TODO: Unfriend by providing the necessary accessors.
-  friend class DecoderTableBuilder;
+  friend class DecoderTreeBuilder;
 
   // Vector of encodings to choose our filter.
   ArrayRef<InstructionEncoding> Encodings;
@@ -581,7 +685,8 @@ class FilterChooser {
   // This returns a list of undecoded bits of an instructions, for example,
   // Inst{20} = 1 && Inst{3-0} == 0b1111 represents two islands of yet-to-be
   // decoded bits in order to verify that the instruction matches the Opcode.
-  std::vector<Island> getIslands(const KnownBits &EncodingBits) const;
+  static std::vector<Island> getIslands(const KnownBits &EncodingBits,
+                                        const KnownBits &FilterBits);
 
   /// Scans the well-known encoding bits of the encodings and, builds up a list
   /// of candidate filters, and then returns the best one, if any.
@@ -600,49 +705,128 @@ class FilterChooser {
   void dump() const;
 };
 
-class DecoderTableBuilder {
+class DecoderTreeBuilder {
   const CodeGenTarget &Target;
   ArrayRef<InstructionEncoding> Encodings;
   DecoderTableInfo &TableInfo;
 
 public:
-  DecoderTableBuilder(const CodeGenTarget &Target,
-                      ArrayRef<InstructionEncoding> Encodings,
-                      DecoderTableInfo &TableInfo)
+  DecoderTreeBuilder(const CodeGenTarget &Target,
+                     ArrayRef<InstructionEncoding> Encodings,
+                     DecoderTableInfo &TableInfo)
       : Target(Target), Encodings(Encodings), TableInfo(TableInfo) {}
 
-  void buildTable(const FilterChooser &FC, unsigned BitWidth) const {
-    // When specializing decoders per bit width, each decoder table will begin
-    // with the bitwidth for that table.
-    if (SpecializeDecodersPerBitwidth)
-      TableInfo.Table.insertULEB128(BitWidth);
-    emitTableEntries(FC);
+  std::unique_ptr<DecoderTreeNode> buildTree(const FilterChooser &FC) {
+    return buildAnyOfNode(FC);
   }
 
 private:
-  void emitBinaryParser(raw_ostream &OS, indent Indent,
-                        const OperandInfo &OpInfo) const;
+  static void emitBinaryParser(raw_ostream &OS, indent Indent,
+                               const OperandInfo &OpInfo);
+  static void emitDecoder(raw_ostream &OS, indent Indent,
+                          const InstructionEncoding &Encoding);
+  unsigned getDecoderIndex(const InstructionEncoding &Encoding);
+
+  static bool emitPredicateMatchAux(StringRef PredicateNamespace,
+                                    const Init &Val, bool ParenIfBinOp,
+                                    raw_ostream &OS);
+  static bool emitPredicateMatch(StringRef PredicateNamespace, raw_ostream &OS,
+                                 const InstructionEncoding &Encoding);
+  unsigned getPredicateIndex(const InstructionEncoding &Encoding) const;
+
+  std::unique_ptr<DecoderTreeNode>
+  buildTerminalNode(unsigned EncodingID, const KnownBits &FilterBits);
+
+  std::unique_ptr<DecoderTreeNode> buildAllOfOrSwitchNode(
+      unsigned StartBit, unsigned NumBits,
+      const std::map<uint64_t, std::unique_ptr<const FilterChooser>> &FCMap);
+
+  std::unique_ptr<DecoderTreeNode> buildAnyOfNode(const FilterChooser &FC);
+};
+
+class DecoderTableEmitter {
+  const CodeGenTarget &Target;
+  ArrayRef<InstructionEncoding> Encodings;
+  formatted_raw_ostream OS;
+  unsigned IndexWidth;
+  unsigned CurrentIndex;
+  unsigned CommentIndex;
+  bool HasCheckPredicate = false;
+  bool HasSoftFail = false;
+  bool HasTryDecode = false;
+
+public:
+  DecoderTableEmitter(const CodeGenTarget &Target,
+                      ArrayRef<InstructionEncoding> Encodings, raw_ostream &OS)
+      : Target(Target), Encodings(Encodings), OS(OS) {}
+
+  void emitTable(StringRef TableName, unsigned BitWidth,
+                 const DecoderTreeNode *Root);
+
+  bool hasCheckPredicate() const { return HasCheckPredicate; }
 
-  void emitDecoder(raw_ostream &OS, indent Indent, unsigned EncodingID) const;
+  bool hasSoftFail() const { return HasSoftFail; }
 
-  unsigned getDecoderIndex(unsigned EncodingID) const;
+  bool hasTryDecode() const { return HasTryDecode; }
 
-  unsigned getPredicateIndex(StringRef P) const;
+private:
+  unsigned computeNodeSize(const DecoderTreeNode *N) const;
 
-  bool emitPredicateMatchAux(const Init &Val, bool ParenIfBinOp,
-                             raw_ostream &OS) const;
+  unsigned computeTableSize(unsigned BitWidth,
+                            const DecoderTreeNode *Root) const {
+    unsigned Size = 0;
+    if (SpecializeDecodersPerBitwidth)
+      Size = getULEB128Size(BitWidth);
+    Size += computeNodeSize(Root);
+    return Size;
+  }
 
-  bool emitPredicateMatch(raw_ostream &OS, unsigned EncodingID) const;
+  void emitStartLine() {
+    CommentIndex = CurrentIndex;
+    OS.indent(2);
+  }
 
-  bool doesOpcodeNeedPredicate(unsigned EncodingID) const;
+  void emitOpcode(StringRef Name) {
+    emitStartLine();
+    OS << "MCD::" << Name << ", ";
+    ++CurrentIndex;
+  }
 
-  void emitPredicateTableEntry(unsigned EncodingID) const;
+  void emitByte(uint8_t Val) {
+    OS << static_cast<unsigned>(Val) << ", ";
+    ++CurrentIndex;
+  }
 
-  void emitSoftFailTableEntry(unsigned EncodingID) const;
+  void emitUInt8(unsigned Val) {
+    assert(isUInt<8>(Val));
+    emitByte(Val);
+  }
 
-  void emitSingletonTableEntry(const FilterChooser &FC) const;
+  void emitULEB128(uint64_t Val) {
+    while (Val >= 0x80) {
+      emitByte((Val & 0x7F) | 0x80);
+      Val >>= 7;
+    }
+    emitByte(Val);
+  }
 
-  void emitTableEntries(const FilterChooser &FC) const;
+  formatted_raw_ostream &emitComment(indent Indent) {
+    constexpr unsigned CommentColumn = 45;
+    if (OS.getColumn() > CommentColumn)
+      OS << '\n';
+    OS.PadToColumn(CommentColumn);
+    OS << "// " << format_decimal(CommentIndex, IndexWidth) << ": " << Indent;
+    return OS;
+  }
+
+  void emitAnyOfNode(const AnyOfNode *N, indent Indent);
+  void emitAllOfNode(const AllOfNode *N, indent Indent);
+  void emitSwitchFieldNode(const SwitchFieldNode *N, indent Indent);
+  void emitCheckFieldNode(const CheckFieldNode *N, indent Indent);
+  void emitCheckPredicateNode(const CheckPredicateNode *N, indent Indent);
+  void emitSoftFailNode(const SoftFailNode *N, indent Indent);
+  void emitDecodeNode(const DecodeNode *N, indent Indent);
+  void emitNode(const DecoderTreeNode *N, indent Indent);
 };
 
 } // end anonymous namespace
@@ -715,204 +899,86 @@ unsigned Filter::usefulness() const {
   return FilteredIDs.size() + VariableIDs.empty();
 }
 
-//////////////////////////////////
-//                              //
-// Filterchooser Implementation //
-//                              //
-//////////////////////////////////
-
-// Emit the decoder state machine table. Returns a mask of MCD decoder ops
-// that were emitted.
-unsigned DecoderEmitter::emitTable(formatted_raw_ostream &OS,
-                                   DecoderTable &Table, StringRef Namespace,
-                                   unsigned HwModeID, unsigned BitWidth,
-                                   ArrayRef<unsigned> EncodingIDs) const {
-  // We'll need to be able to map from a decoded opcode into the corresponding
-  // EncodingID for this specific combination of BitWidth and Namespace. This
-  // is used below to index into Encodings.
-  DenseMap<unsigned, unsigned> OpcodeToEncodingID;
-  OpcodeToEncodingID.reserve(EncodingIDs.size());
-  for (unsigned EncodingID : EncodingIDs) {
-    const Record *InstDef = Encodings[EncodingID].getInstruction()->TheDef;
-    OpcodeToEncodingID[Target.getInstrIntValue(InstDef)] = EncodingID;
-  }
-
-  OS << "static const uint8_t DecoderTable" << Namespace;
-  if (HwModeID != DefaultMode)
-    OS << '_' << Target.getHwModes().getModeName(HwModeID);
-  OS << BitWidth << "[" << Table.size() << "] = {\n";
-
-  // Emit ULEB128 encoded value to OS, returning the number of bytes emitted.
-  auto emitULEB128 = [](DecoderTable::const_iterator &I,
-                        formatted_raw_ostream &OS) {
-    while (*I >= 128)
-      OS << (unsigned)*I++ << ", ";
-    OS << (unsigned)*I++ << ", ";
-  };
+static bool doesOpcodeNeedPredicate(const InstructionEncoding &Encoding);
 
-  // Emit `getNumToSkipInBytes()`-byte numtoskip value to OS, returning the
-  // NumToSkip value.
-  auto emitNumToSkip = [](DecoderTable::const_iterator &I,
-                          formatted_raw_ostream &OS) {
-    uint8_t Byte = *I++;
-    uint32_t NumToSkip = Byte;
-    OS << (unsigned)Byte << ", ";
-    Byte = *I++;
-    OS << (unsigned)Byte << ", ";
-    NumToSkip |= Byte << 8;
-    if (getNumToSkipInBytes() == 3) {
-      Byte = *I++;
-      OS << (unsigned)(Byte) << ", ";
-      NumToSkip |= Byte << 16;
-    }
-    return NumToSkip;
-  };
-
-  // FIXME: We may be able to use the NumToSkip values to recover
-  // appropriate indentation levels.
-  DecoderTable::const_iterator I = Table.begin();
-  DecoderTable::const_iterator E = Table.end();
-  const uint8_t *const EndPtr = Table.data() + Table.size();
-
-  auto emitNumToSkipComment = [&](uint32_t NumToSkip, bool InComment = false) {
-    uint32_t Index = ((I - Table.begin()) + NumToSkip);
-    OS << (InComment ? ", " : "// ");
-    OS << "Skip to: " << Index;
-  };
+std::unique_ptr<DecoderTreeNode>
+DecoderTreeBuilder::buildTerminalNode(unsigned EncodingID,
+                                      const KnownBits &FilterBits) {
+  const InstructionEncoding &Encoding = Encodings[EncodingID];
+  auto N = std::make_unique<AllOfNode>();
 
-  // The first entry when specializing decoders per bitwidth is the bitwidth.
-  // This will be used for additional checks in `decodeInstruction`.
-  if (SpecializeDecodersPerBitwidth) {
-    OS << "/* 0  */";
-    OS.PadToColumn(14);
-    emitULEB128(I, OS);
-    OS << " // Bitwidth " << BitWidth << '\n';
+  if (doesOpcodeNeedPredicate(Encoding)) {
+    unsigned PredicateIndex = getPredicateIndex(Encoding);
+    N->addChild(std::make_unique<CheckPredicateNode>(PredicateIndex));
   }
 
-  unsigned OpcodeMask = 0;
+  std::vector<FilterChooser::Island> Islands =
+      FilterChooser::getIslands(Encoding.getMandatoryBits(), FilterBits);
+  for (const FilterChooser::Island &Ilnd : reverse(Islands)) {
+    N->addChild(std::make_unique<CheckFieldNode>(Ilnd.StartBit, Ilnd.NumBits,
+                                                 Ilnd.FieldVal));
+  }
 
-  while (I != E) {
-    assert(I < E && "incomplete decode table entry!");
+  const KnownBits &InstBits = Encoding.getInstBits();
+  const APInt &SoftFailMask = Encoding.getSoftFailMask();
+  if (!SoftFailMask.isZero()) {
+    APInt PositiveMask = InstBits.Zero & SoftFailMask;
+    APInt NegativeMask = InstBits.One & SoftFailMask;
+    N->addChild(std::make_unique<SoftFailNode>(PositiveMask.getZExtValue(),
+                                               NegativeMask.getZExtValue()));
+  }
 
-    uint64_t Pos = I - Table.begin();
-    OS << "/* " << Pos << " */";
-    OS.PadToColumn(12);
+  unsigned DecoderIndex = getDecoderIndex(Encoding);
+  N->addChild(std::make_unique<DecodeNode>(EncodingID, DecoderIndex));
 
-    const uint8_t DecoderOp = *I++;
-    OpcodeMask |= (1 << DecoderOp);
-    switch (DecoderOp) {
-    default:
-      PrintFatalError("Invalid decode table opcode: " + Twine((int)DecoderOp) +
-                      " at index " + Twine(Pos));
-    case MCD::OPC_Scope: {
-      OS << "  MCD::OPC_Scope, ";
-      uint32_t NumToSkip = emitNumToSkip(I, OS);
-      emitNumToSkipComment(NumToSkip);
-      OS << '\n';
-      break;
-    }
-    case MCD::OPC_ExtractField: {
-      OS << "  MCD::OPC_ExtractField, ";
-
-      // ULEB128 encoded start value.
-      const char *ErrMsg = nullptr;
-      unsigned Start = decodeULEB128(&*I, nullptr, EndPtr, &ErrMsg);
-      assert(ErrMsg == nullptr && "ULEB128 value too large!");
-      emitULEB128(I, OS);
-
-      unsigned Len = *I++;
-      OS << Len << ",  // Inst{";
-      if (Len > 1)
-        OS << (Start + Len - 1) << "-";
-      OS << Start << "} ...\n";
-      break;
-    }
-    case MCD::OPC_FilterValueOrSkip: {
-      OS << "  MCD::OPC_FilterValueOrSkip, ";
-      // The filter value is ULEB128 encoded.
-      emitULEB128(I, OS);
-      uint32_t NumToSkip = emitNumToSkip(I, OS);
-      emitNumToSkipComment(NumToSkip);
-      OS << '\n';
-      break;
-    }
-    case MCD::OPC_FilterValue: {
-      OS << "  MCD::OPC_FilterValue, ";
-      // The filter value is ULEB128 encoded.
-      emitULEB128(I, OS);
-      OS << '\n';
-      break;
-    }
-    case MCD::OPC_CheckField: {
-      OS << "  MCD::OPC_CheckField, ";
-      // ULEB128 encoded start value.
-      emitULEB128(I, OS);
-      // 8-bit length.
-      unsigned Len = *I++;
-      OS << Len << ", ";
-      // ULEB128 encoded field value.
-      emitULEB128(I, OS);
-      OS << '\n';
-      break;
-    }
-    case MCD::OPC_CheckPredicate: {
-      OS << "  MCD::OPC_CheckPredicate, ";
-      emitULEB128(I, OS);
-      OS << '\n';
-      break;
-    }
-    case MCD::OPC_Decode:
-    case MCD::OPC_TryDecode: {
-      bool IsTry = DecoderOp == MCD::OPC_TryDecode;
-      // Decode the Opcode value.
-      const char *ErrMsg = nullptr;
-      unsigned Opc = decodeULEB128(&*I, nullptr, EndPtr, &ErrMsg);
-      assert(ErrMsg == nullptr && "ULEB128 value too large!");
+  return N;
+}
 
-      OS << "  MCD::OPC_" << (IsTry ? "Try" : "") << "Decode, ";
-      emitULEB128(I, OS);
+std::unique_ptr<DecoderTreeNode> DecoderTreeBuilder::buildAllOfOrSwitchNode(
+    unsigned int StartBit, unsigned int NumBits,
+    const std::map<uint64_t, std::unique_ptr<const FilterChooser>> &FCMap) {
+  if (FCMap.size() == 1) {
+    const auto &[FieldVal, ChildFC] = *FCMap.begin();
+    auto N = std::make_unique<AllOfNode>();
+    N->addChild(std::make_unique<CheckFieldNode>(StartBit, NumBits, FieldVal));
+    N->addChild(buildAnyOfNode(*ChildFC));
+    return N;
+  }
+  auto N = std::make_unique<SwitchFieldNode>(StartBit, NumBits);
+  for (const auto &[FieldVal, ChildFC] : FCMap)
+    N->addCase(FieldVal, buildAnyOfNode(*ChildFC));
+  return N;
+}
 
-      // Decoder index.
-      unsigned DecodeIdx = decodeULEB128(&*I, nullptr, EndPtr, &ErrMsg);
-      assert(ErrMsg == nullptr && "ULEB128 value too large!");
-      emitULEB128(I, OS);
+std::unique_ptr<DecoderTreeNode>
+DecoderTreeBuilder::buildAnyOfNode(const FilterChooser &FC) {
+  auto N = std::make_unique<AnyOfNode>();
+  if (FC.SingletonEncodingID) {
+    N->addChild(buildTerminalNode(*FC.SingletonEncodingID, FC.FilterBits));
+  } else {
+    N->addChild(
+        buildAllOfOrSwitchNode(FC.StartBit, FC.NumBits, FC.FilterChooserMap));
+  }
+  if (FC.VariableFC) {
+    N->addChild(buildAnyOfNode(*FC.VariableFC));
+  }
 
-      auto EncI = OpcodeToEncodingID.find(Opc);
-      assert(EncI != OpcodeToEncodingID.end() && "no encoding entry");
-      auto EncodingID = EncI->second;
+  return N;
+}
 
-      if (!IsTry) {
-        OS << "// Opcode: " << Encodings[EncodingID].getName()
-           << ", DecodeIdx: " << DecodeIdx << '\n';
-        break;
-      }
-      OS << '\n';
-      break;
-    }
-    case MCD::OPC_SoftFail: {
-      OS << "  MCD::OPC_SoftFail, ";
-      // Decode the positive mask.
-      const char *ErrMsg = nullptr;
-      uint64_t PositiveMask = decodeULEB128(&*I, nullptr, EndPtr, &ErrMsg);
-      assert(ErrMsg == nullptr && "ULEB128 value too large!");
-      emitULEB128(I, OS);
-
-      // Decode the negative mask.
-      uint64_t NegativeMask = decodeULEB128(&*I, nullptr, EndPtr, &ErrMsg);
-      assert(ErrMsg == nullptr && "ULEB128 value too large!");
-      emitULEB128(I, OS);
-      OS << "// +ve mask: 0x";
-      OS.write_hex(PositiveMask);
-      OS << ", -ve mask: 0x";
-      OS.write_hex(NegativeMask);
-      OS << '\n';
-      break;
-    }
-    }
-  }
-  OS << "};\n\n";
+// Emit the decoder state machine table.
+void DecoderEmitter::emitDecoderTable(formatted_raw_ostream &OS,
+                                      const DecoderTreeNode *Tree,
+                                      StringRef Namespace, unsigned HwModeID,
+                                      unsigned BitWidth) const {
+  SmallString<32> TableName("DecoderTable");
+  TableName.append(Namespace);
+  if (HwModeID != DefaultMode)
+    TableName.append({"_", Target.getHwModes().getModeName(HwModeID)});
+  TableName.append(std::to_string(BitWidth));
 
-  return OpcodeMask;
+  DecoderTableEmitter TableEmitter(Target, Encodings, OS);
+  TableEmitter.emitTable(TableName, BitWidth, Tree);
 }
 
 void DecoderEmitter::emitInstrLenTable(formatted_raw_ostream &OS,
@@ -1042,7 +1108,8 @@ void FilterChooser::dumpStack(raw_ostream &OS, indent Indent,
 // Inst{20} = 1 && Inst{3-0} == 0b1111 represents two islands of yet-to-be
 // decoded bits in order to verify that the instruction matches the Opcode.
 std::vector<FilterChooser::Island>
-FilterChooser::getIslands(const KnownBits &EncodingBits) const {
+FilterChooser::getIslands(const KnownBits &EncodingBits,
+                          const KnownBits &FilterBits) {
   std::vector<Island> Islands;
   uint64_t FieldVal;
   unsigned StartBit;
@@ -1055,7 +1122,7 @@ FilterChooser::getIslands(const KnownBits &EncodingBits) const {
   unsigned FilterWidth = FilterBits.getBitWidth();
   for (unsigned i = 0; i != FilterWidth; ++i) {
     bool IsKnown = EncodingBits.Zero[i] || EncodingBits.One[i];
-    bool Filtered = isPositionFiltered(i);
+    bool Filtered = FilterBits.Zero[i] || FilterBits.One[i];
     switch (State) {
     default:
       llvm_unreachable("Unreachable code!");
@@ -1088,8 +1155,8 @@ FilterChooser::getIslands(const KnownBits &EncodingBits) const {
   return Islands;
 }
 
-void DecoderTableBuilder::emitBinaryParser(raw_ostream &OS, indent Indent,
-                                           const OperandInfo &OpInfo) const {
+void DecoderTreeBuilder::emitBinaryParser(raw_ostream &OS, indent Indent,
+                                          const OperandInfo &OpInfo) {
   // Special case for 'bits<0>'.
   if (OpInfo.Fields.empty() && !OpInfo.InitValue) {
     if (IgnoreNonDecodableOperands)
@@ -1139,10 +1206,8 @@ void DecoderTableBuilder::emitBinaryParser(raw_ostream &OS, indent Indent,
   }
 }
 
-void DecoderTableBuilder::emitDecoder(raw_ostream &OS, indent Indent,
-                                      unsigned EncodingID) const {
-  const InstructionEncoding &Encoding = Encodings[EncodingID];
-
+void DecoderTreeBuilder::emitDecoder(raw_ostream &OS, indent Indent,
+                                     const InstructionEncoding &Encoding) {
   // If a custom instruction decoder was specified, use that.
   StringRef DecoderMethod = Encoding.getDecoderMethod();
   if (!DecoderMethod.empty()) {
@@ -1157,14 +1222,15 @@ void DecoderTableBuilder::emitDecoder(raw_ostream &OS, indent Indent,
     emitBinaryParser(OS, Indent, Op);
 }
 
-unsigned DecoderTableBuilder::getDecoderIndex(unsigned EncodingID) const {
+unsigned
+DecoderTreeBuilder::getDecoderIndex(const InstructionEncoding &Encoding) {
   // Build up the predicate string.
   SmallString<256> Decoder;
   // FIXME: emitDecoder() function can take a buffer directly rather than
   // a stream.
   raw_svector_ostream S(Decoder);
   indent Indent(UseFnTableInDecodeToMCInst ? 2 : 4);
-  emitDecoder(S, Indent, EncodingID);
+  emitDecoder(S, Indent, Encoding);
 
   // Using the full decoder string as the key value here is a bit
   // heavyweight, but is effective. If the string comparisons become a
@@ -1181,20 +1247,21 @@ unsigned DecoderTableBuilder::getDecoderIndex(unsigned EncodingID) const {
 }
 
 // If ParenIfBinOp is true, print a surrounding () if Val uses && or ||.
-bool DecoderTableBuilder::emitPredicateMatchAux(const Init &Val,
-                                                bool ParenIfBinOp,
-                                                raw_ostream &OS) const {
+bool DecoderTreeBuilder::emitPredicateMatchAux(StringRef PredicateNamespace,
+                                               const Init &Val,
+                                               bool ParenIfBinOp,
+                                               raw_ostream &OS) {
   if (const auto *D = dyn_cast<DefInit>(&Val)) {
     if (!D->getDef()->isSubClassOf("SubtargetFeature"))
       return true;
-    OS << "Bits[" << Target.getName() << "::" << D->getAsString() << "]";
+    OS << "Bits[" << PredicateNamespace << "::" << D->getAsString() << "]";
     return false;
   }
   if (const auto *D = dyn_cast<DagInit>(&Val)) {
     std::string Op = D->getOperator()->getAsString();
     if (Op == "not" && D->getNumArgs() == 1) {
       OS << '!';
-      return emitPredicateMatchAux(*D->getArg(0), true, OS);
+      return emitPredicateMatchAux(PredicateNamespace, *D->getArg(0), true, OS);
     }
     if ((Op == "any_of" || Op == "all_of") && D->getNumArgs() > 0) {
       bool Paren = D->getNumArgs() > 1 && std::exchange(ParenIfBinOp, true);
@@ -1203,7 +1270,7 @@ bool DecoderTableBuilder::emitPredicateMatchAux(const Init &Val,
       ListSeparator LS(Op == "any_of" ? " || " : " && ");
       for (auto *Arg : D->getArgs()) {
         OS << LS;
-        if (emitPredicateMatchAux(*Arg, ParenIfBinOp, OS))
+        if (emitPredicateMatchAux(PredicateNamespace, *Arg, ParenIfBinOp, OS))
           return true;
       }
       if (Paren)
@@ -1214,10 +1281,11 @@ bool DecoderTableBuilder::emitPredicateMatchAux(const Init &Val,
   return true;
 }
 
-bool DecoderTableBuilder::emitPredicateMatch(raw_ostream &OS,
-                                             unsigned EncodingID) const {
+bool DecoderTreeBuilder::emitPredicateMatch(
+    StringRef PredicateNamespace, raw_ostream &OS,
+    const InstructionEncoding &Encoding) {
   const ListInit *Predicates =
-      Encodings[EncodingID].getRecord()->getValueAsListInit("Predicates");
+      Encoding.getRecord()->getValueAsListInit("Predicates");
   bool IsFirstEmission = true;
   for (unsigned i = 0; i < Predicates->size(); ++i) {
     const Record *Pred = Predicates->getElementAsRecord(i);
@@ -1229,7 +1297,8 @@ bool DecoderTableBuilder::emitPredicateMatch(raw_ostream &OS,
 
     if (!IsFirstEmission)
       OS << " && ";
-    if (emitPredicateMatchAux(*Pred->getValueAsDag("AssemblerCondDag"),
+    if (emitPredicateMatchAux(PredicateNamespace,
+                              *Pred->getValueAsDag("AssemblerCondDag"),
                               Predicates->size() > 1, OS))
       PrintFatalError(Pred->getLoc(), "Invalid AssemblerCondDag!");
     IsFirstEmission = false;
@@ -1237,9 +1306,9 @@ bool DecoderTableBuilder::emitPredicateMatch(raw_ostream &OS,
   return !Predicates->empty();
 }
 
-bool DecoderTableBuilder::doesOpcodeNeedPredicate(unsigned EncodingID) const {
+static bool doesOpcodeNeedPredicate(const InstructionEncoding &Encoding) {
   const ListInit *Predicates =
-      Encodings[EncodingID].getRecord()->getValueAsListInit("Predicates");
+      Encoding.getRecord()->getValueAsListInit("Predicates");
   for (unsigned i = 0; i < Predicates->size(); ++i) {
     const Record *Pred = Predicates->getElementAsRecord(i);
     if (!Pred->getValue("AssemblerMatcherPredicate"))
@@ -1251,7 +1320,15 @@ bool DecoderTableBuilder::doesOpcodeNeedPredicate(unsigned EncodingID) const {
   return false;
 }
 
-unsigned DecoderTableBuilder::getPredicateIndex(StringRef Predicate) const {
+unsigned DecoderTreeBuilder::getPredicateIndex(
+    const InstructionEncoding &Encoding) const {
+  // Build up the predicate string.
+  SmallString<256> Predicate;
+  // FIXME: emitPredicateMatch() functions can take a buffer directly rather
+  // than a stream.
+  raw_svector_ostream PS(Predicate);
+  emitPredicateMatch(Target.getName(), PS, Encoding);
+
   // Using the full predicate string as the key value here is a bit
   // heavyweight, but is effective. If the string comparisons become a
   // performance concern, we can implement a mangling of the predicate
@@ -1259,88 +1336,11 @@ unsigned DecoderTableBuilder::getPredicateIndex(StringRef Predicate) const {
   // overkill for now, though.
 
   // Make sure the predicate is in the table.
-  TableInfo.Predicates.insert(CachedHashString(Predicate));
+  PredicateSet &Predicates = TableInfo.Predicates;
+  Predicates.insert(CachedHashString(Predicate));
   // Now figure out the index for when we write out the table.
-  PredicateSet::const_iterator P = find(TableInfo.Predicates, Predicate);
-  return (unsigned)(P - TableInfo.Predicates.begin());
-}
-
-void DecoderTableBuilder::emitPredicateTableEntry(unsigned EncodingID) const {
-  if (!doesOpcodeNeedPredicate(EncodingID))
-    return;
-
-  // Build up the predicate string.
-  SmallString<256> Predicate;
-  // FIXME: emitPredicateMatch() functions can take a buffer directly rather
-  // than a stream.
-  raw_svector_ostream PS(Predicate);
-  emitPredicateMatch(PS, EncodingID);
-
-  // Figure out the index into the predicate table for the predicate just
-  // computed.
-  unsigned PIdx = getPredicateIndex(PS.str());
-
-  TableInfo.Table.insertOpcode(MCD::OPC_CheckPredicate);
-  TableInfo.Table.insertULEB128(PIdx);
-}
-
-void DecoderTableBuilder::emitSoftFailTableEntry(unsigned EncodingID) const {
-  const InstructionEncoding &Encoding = Encodings[EncodingID];
-  const KnownBits &InstBits = Encoding.getInstBits();
-  const APInt &SoftFailMask = Encoding.getSoftFailMask();
-
-  if (SoftFailMask.isZero())
-    return;
-
-  APInt PositiveMask = InstBits.Zero & SoftFailMask;
-  APInt NegativeMask = InstBits.One & SoftFailMask;
-
-  TableInfo.Table.insertOpcode(MCD::OPC_SoftFail);
-  TableInfo.Table.insertULEB128(PositiveMask.getZExtValue());
-  TableInfo.Table.insertULEB128(NegativeMask.getZExtValue());
-}
-
-// Emits table entries to decode the singleton.
-void DecoderTableBuilder::emitSingletonTableEntry(
-    const FilterChooser &FC) const {
-  unsigned EncodingID = *FC.SingletonEncodingID;
-  const InstructionEncoding &Encoding = Encodings[EncodingID];
-  KnownBits EncodingBits = Encoding.getMandatoryBits();
-
-  // Look for islands of undecoded bits of the singleton.
-  std::vector<FilterChooser::Island> Islands = FC.getIslands(EncodingBits);
-
-  // Emit the predicate table entry if one is needed.
-  emitPredicateTableEntry(EncodingID);
-
-  // Check any additional encoding fields needed.
-  for (const FilterChooser::Island &Ilnd : reverse(Islands)) {
-    TableInfo.Table.insertOpcode(MCD::OPC_CheckField);
-    TableInfo.Table.insertULEB128(Ilnd.StartBit);
-    TableInfo.Table.insertUInt8(Ilnd.NumBits);
-    TableInfo.Table.insertULEB128(Ilnd.FieldVal);
-  }
-
-  // Check for soft failure of the match.
-  emitSoftFailTableEntry(EncodingID);
-
-  unsigned DIdx = getDecoderIndex(EncodingID);
-
-  // Produce OPC_Decode or OPC_TryDecode opcode based on the information
-  // whether the instruction decoder is complete or not. If it is complete
-  // then it handles all possible values of remaining variable/unfiltered bits
-  // and for any value can determine if the bitpattern is a valid instruction
-  // or not. This means OPC_Decode will be the final step in the decoding
-  // process. If it is not complete, then the Fail return code from the
-  // decoder method indicates that additional processing should be done to see
-  // if there is any other instruction that also matches the bitpattern and
-  // can decode it.
-  const MCD::DecoderOps DecoderOp =
-      Encoding.hasCompleteDecoder() ? MCD::OPC_Decode : MCD::OPC_TryDecode;
-  TableInfo.Table.insertOpcode(DecoderOp);
-  const Record *InstDef = Encodings[EncodingID].getInstruction()->TheDef;
-  TableInfo.Table.insertULEB128(Target.getInstrIntValue(InstDef));
-  TableInfo.Table.insertULEB128(DIdx);
+  PredicateSet::const_iterator P = find(Predicates, Predicate);
+  return std::distance(Predicates.begin(), P);
 }
 
 std::unique_ptr<Filter>
@@ -1358,7 +1358,7 @@ FilterChooser::findBestFilter(ArrayRef<bitAttr_t> BitAttrs, bool AllowMixed,
       KnownBits EncodingBits = Encoding.getMandatoryBits();
 
       // Look for islands of undecoded bits of any instruction.
-      std::vector<Island> Islands = getIslands(EncodingBits);
+      std::vector<Island> Islands = getIslands(EncodingBits, FilterBits);
       if (!Islands.empty()) {
         // Found an instruction with island(s).  Now just assign a filter.
         return std::make_unique<Filter>(
@@ -1618,67 +1618,227 @@ void FilterChooser::dump() const {
   }
 }
 
-void DecoderTableBuilder::emitTableEntries(const FilterChooser &FC) const {
-  DecoderTable &Table = TableInfo.Table;
-
-  // If there are other encodings that could match if those with all bits
-  // known don't, enter a scope so that they have a chance.
-  size_t FixupLoc = 0;
-  if (FC.VariableFC) {
-    Table.insertOpcode(MCD::OPC_Scope);
-    FixupLoc = Table.insertNumToSkip();
+unsigned DecoderTableEmitter::computeNodeSize(const DecoderTreeNode *N) const {
+  switch (N->getKind()) {
+  case DecoderTreeNode::AnyOf: {
+    const auto *CheckAny = static_cast<const AnyOfNode *>(N);
+    unsigned Size = 0;
+    for (const DecoderTreeNode *Child : drop_end(CheckAny->children())) {
+      unsigned ChildSize = computeNodeSize(Child);
+      Size += 1 + getULEB128Size(ChildSize) + ChildSize;
+    }
+    return Size + computeNodeSize(*std::prev(CheckAny->child_end()));
+  }
+  case DecoderTreeNode::AllOf: {
+    const auto *CheckAll = static_cast<const AllOfNode *>(N);
+    unsigned Size = 0;
+    for (const DecoderTreeNode *Child : CheckAll->children())
+      Size += computeNodeSize(Child);
+    return Size;
   }
+  case DecoderTreeNode::CheckField: {
+    const auto *CheckField = static_cast<const CheckFieldNode *>(N);
+    return 1 + getULEB128Size(CheckField->getStartBit()) + 1 +
+           getULEB128Size(CheckField->getValue());
+  }
+  case DecoderTreeNode::SwitchField: {
+    const auto *SwitchN = static_cast<const SwitchFieldNode *>(N);
+    unsigned Size = 1 + getULEB128Size(SwitchN->getStartBit()) + 1;
 
-  if (FC.SingletonEncodingID) {
-    assert(FC.FilterChooserMap.empty());
-    // There is only one encoding in which all bits in the filtered range are
-    // fully defined, but we still need to check if the remaining (unfiltered)
-    // bits are valid for this encoding. We also need to check predicates etc.
-    emitSingletonTableEntry(FC);
-  } else if (FC.FilterChooserMap.size() == 1) {
-    // If there is only one possible field value, emit a combined OPC_CheckField
-    // instead of OPC_ExtractField + OPC_FilterValue.
-    const auto &[FilterVal, Delegate] = *FC.FilterChooserMap.begin();
-    Table.insertOpcode(MCD::OPC_CheckField);
-    Table.insertULEB128(FC.StartBit);
-    Table.insertUInt8(FC.NumBits);
-    Table.insertULEB128(FilterVal);
-
-    // Emit table entries for the only case.
-    emitTableEntries(*Delegate);
-  } else {
-    // The general case: emit a switch over the field value.
-    Table.insertOpcode(MCD::OPC_ExtractField);
-    Table.insertULEB128(FC.StartBit);
-    Table.insertUInt8(FC.NumBits);
-
-    // Emit switch cases for all but the last element.
-    for (const auto &[FilterVal, Delegate] : drop_end(FC.FilterChooserMap)) {
-      Table.insertOpcode(MCD::OPC_FilterValueOrSkip);
-      Table.insertULEB128(FilterVal);
-      size_t FixupPos = Table.insertNumToSkip();
-
-      // Emit table entries for this case.
-      emitTableEntries(*Delegate);
-
-      // Patch the previous FilterValueOrSkip to fall through to the next case.
-      Table.patchNumToSkip(FixupPos, Table.size());
+    for (auto [Val, Child] : drop_end(SwitchN->cases())) {
+      unsigned ChildSize = computeNodeSize(Child);
+      Size += getULEB128Size(Val) + getULEB128Size(ChildSize) + ChildSize;
     }
 
-    // Emit a switch case for the last element. It never falls through;
-    // if it doesn't match, we leave the current scope.
-    const auto &[FilterVal, Delegate] = *FC.FilterChooserMap.rbegin();
-    Table.insertOpcode(MCD::OPC_FilterValue);
-    Table.insertULEB128(FilterVal);
+    auto [Val, Child] = *std::prev(SwitchN->case_end());
+    unsigned ChildSize = computeNodeSize(Child);
+    Size += getULEB128Size(Val) + getULEB128Size(0) + ChildSize;
+    return Size;
+  }
+  case DecoderTreeNode::CheckPredicate: {
+    const auto *CheckPredicate = static_cast<const CheckPredicateNode *>(N);
+    return 1 + getULEB128Size(CheckPredicate->getPredicateIndex());
+  }
+  case DecoderTreeNode::SoftFail: {
+    const auto *SoftFail = static_cast<const SoftFailNode *>(N);
+    return 1 + getULEB128Size(SoftFail->getPositiveMask()) +
+           getULEB128Size(SoftFail->getNegativeMask());
+  }
+  case DecoderTreeNode::Decode: {
+    const auto *Decode = static_cast<const DecodeNode *>(N);
+    const InstructionEncoding &Encoding = Encodings[Decode->getEncodingID()];
+    const Record *InstDef = Encoding.getInstruction()->TheDef;
+    unsigned InstOpcode = Target.getInstrIntValue(InstDef);
+    return 1 + getULEB128Size(InstOpcode) +
+           getULEB128Size(Decode->getDecoderIndex());
+  }
+  }
+  llvm_unreachable("Unknown node kind");
+}
+
+void DecoderTableEmitter::emitAnyOfNode(const AnyOfNode *N, indent Indent) {
+  for (const DecoderTreeNode *Child : drop_end(N->children())) {
+    emitOpcode("OPC_Scope");
+    emitULEB128(computeNodeSize(Child));
 
-    // Emit table entries for the last case.
-    emitTableEntries(*Delegate);
+    emitComment(Indent) << "{\n";
+    emitNode(Child, Indent + 1);
+    emitComment(Indent) << "}\n";
   }
 
-  if (FC.VariableFC) {
-    Table.patchNumToSkip(FixupLoc, Table.size());
-    emitTableEntries(*FC.VariableFC);
+  const DecoderTreeNode *Child = *std::prev(N->child_end());
+  emitNode(Child, Indent);
+}
+
+void DecoderTableEmitter::emitAllOfNode(const AllOfNode *N, indent Indent) {
+  for (const auto *Child : N->children()) // Emit in order, same indent.
+    emitNode(Child, Indent);
+}
+
+void DecoderTableEmitter::emitSwitchFieldNode(const SwitchFieldNode *N,
+                                              indent Indent) {
+  unsigned LSB = N->getStartBit();
+  unsigned Width = N->getNumBits();
+  unsigned MSB = LSB + Width - 1;
+  // Header: opcode, ULEB128 start bit, uint8 field width.
+  emitOpcode("OPC_SwitchField");
+  emitULEB128(LSB);
+  emitUInt8(Width);
+  emitComment(Indent) << "switch Inst[" << MSB << ':' << LSB << "] {\n";
+
+  // Each case is (ULEB128 value, ULEB128 body size, body). The decoder reads
+  // size 0 as "last case", so no other case may have an empty body.
+  auto EmitCase = [&](uint64_t Val, const DecoderTreeNode *Child,
+                      unsigned Size) {
+    emitStartLine();
+    emitULEB128(Val);
+    emitULEB128(Size);
+    emitComment(Indent) << "case " << format_hex(Val, 0) << ": {\n";
+    emitNode(Child, Indent + 1);
+    emitComment(Indent) << "}\n";
+  };
+
+  for (auto [Val, Child] : drop_end(N->cases())) {
+    unsigned ChildSize = computeNodeSize(Child);
+    assert(ChildSize != 0 && "Empty case would be read as the last case");
+    EmitCase(Val, Child, ChildSize);
+  }
+  auto [Val, Child] = *std::prev(N->case_end());
+  EmitCase(Val, Child, /*Size=*/0);
+
+  emitComment(Indent) << "} // switch Inst[" << MSB << ':' << LSB << "]\n";
+}
+
+void DecoderTableEmitter::emitCheckFieldNode(const CheckFieldNode *N,
+                                             indent Indent) {
+  unsigned LSB = N->getStartBit();
+  unsigned Width = N->getNumBits();
+  unsigned MSB = LSB + Width - 1;
+  uint64_t Val = N->getValue();
+
+  // Layout: opcode, ULEB128 start bit, uint8 width, ULEB128 expected value.
+  emitOpcode("OPC_CheckField");
+  emitULEB128(LSB);
+  emitUInt8(Width);
+  emitULEB128(Val);
+
+  emitComment(Indent) << "check Inst[" << MSB << ':' << LSB
+                      << "] == " << format_hex(Val, 0) << '\n';
+}
+
+void DecoderTableEmitter::emitCheckPredicateNode(const CheckPredicateNode *N,
+                                                 indent Indent) {
+  // Layout: opcode followed by the ULEB128-encoded predicate index.
+  const unsigned PIdx = N->getPredicateIndex();
+  emitOpcode("OPC_CheckPredicate");
+  emitULEB128(PIdx);
+  // Human-readable annotation of the emitted bytes.
+  emitComment(Indent) << "check predicate " << PIdx << "\n";
+}
+
+void DecoderTableEmitter::emitSoftFailNode(const SoftFailNode *N,
+                                           indent Indent) {
+  // Layout: opcode, ULEB128 positive mask, ULEB128 negative mask.
+  const uint64_t Pos = N->getPositiveMask();
+  const uint64_t Neg = N->getNegativeMask();
+  emitOpcode("OPC_SoftFail");
+  emitULEB128(Pos);
+  emitULEB128(Neg);
+  // Human-readable annotation; both masks are printed as 0x-prefixed hex.
+  emitComment(Indent) << "check softfail"
+                      << " pos=" << format_hex(Pos, 10)
+                      << " neg=" << format_hex(Neg, 10) << '\n';
+}
+
+void DecoderTableEmitter::emitDecodeNode(const DecodeNode *N, indent Indent) {
+  const InstructionEncoding &Enc = Encodings[N->getEncodingID()];
+  const Record *InstDef = Enc.getInstruction()->TheDef;
+  const unsigned Opc = Target.getInstrIntValue(InstDef);
+  const unsigned DIdx = N->getDecoderIndex();
+  // Incomplete decoders get OPC_TryDecode so other encodings can be tried.
+  const bool IsComplete = Enc.hasCompleteDecoder();
+
+  emitOpcode(IsComplete ? "OPC_Decode" : "OPC_TryDecode");
+  emitULEB128(Opc);
+  emitULEB128(DIdx);
+
+  // Annotate: "[try ]decode to <name> using decoder <index>".
+  emitComment(Indent) << (IsComplete ? "" : "try ") << "decode to "
+                      << Enc.getName() << " using decoder " << DIdx << '\n';
+}
+
+void DecoderTableEmitter::emitNode(const DecoderTreeNode *N, indent Indent) {
+  switch (N->getKind()) {
+  case DecoderTreeNode::AnyOf:
+    return emitAnyOfNode(static_cast<const AnyOfNode *>(N), Indent);
+  case DecoderTreeNode::AllOf:
+    return emitAllOfNode(static_cast<const AllOfNode *>(N), Indent);
+  case DecoderTreeNode::SwitchField:
+    return emitSwitchFieldNode(static_cast<const SwitchFieldNode *>(N), Indent);
+  case DecoderTreeNode::CheckField:
+    return emitCheckFieldNode(static_cast<const CheckFieldNode *>(N), Indent);
+  case DecoderTreeNode::CheckPredicate:
+    HasCheckPredicate = true;
+    return emitCheckPredicateNode(static_cast<const CheckPredicateNode *>(N),
+                                  Indent);
+  case DecoderTreeNode::SoftFail:
+    HasSoftFail = true;
+    return emitSoftFailNode(static_cast<const SoftFailNode *>(N), Indent);
+  case DecoderTreeNode::Decode: { // OPC_TryDecode iff decoder incomplete.
+    const auto *Decode = static_cast<const DecodeNode *>(N);
+    HasTryDecode |= !Encodings[Decode->getEncodingID()].hasCompleteDecoder();
+    return emitDecodeNode(Decode, Indent);
+  }
+  }
+  llvm_unreachable("Unknown node kind");
+}
+
+void DecoderTableEmitter::emitTable(StringRef TableName, unsigned BitWidth,
+                                    const DecoderTreeNode *Root) {
+  unsigned TableSize = computeTableSize(BitWidth, Root);
+  OS << "static const uint8_t " << TableName << "[" << TableSize << "] = {\n";
+
+  // Number of decimal digits in TableSize (floor(log10) + 1, at least 1),
+  // which is wide enough to print any table index.
+  IndexWidth = 1;
+  for (unsigned S = TableSize; S /= 10;)
+    ++IndexWidth;
+
+  // Counts emitted bytes; must equal TableSize once emission is done.
+  CurrentIndex = 0;
+
+  // With per-bitwidth decoders, each table begins with its bit width.
+  if (SpecializeDecodersPerBitwidth) {
+    emitStartLine();
+    emitULEB128(BitWidth);
+    emitComment(indent(0)) << "BitWidth " << BitWidth << '\n';
   }
+
+  emitNode(Root, indent(0));
+  assert(CurrentIndex == TableSize &&
+         "The size of the emitted table differs from the calculated one");
+
+  OS << "};\n";
 }
 
 static std::string findOperandDecoderMethod(const Record *Record) {
@@ -2100,21 +2260,9 @@ InstructionEncoding::InstructionEncoding(const Record *EncodingDef,
 // emitDecodeInstruction - Emit the templated helper function
 // decodeInstruction().
 static void emitDecodeInstruction(formatted_raw_ostream &OS, bool IsVarLenInst,
-                                  unsigned OpcodeMask) {
-  const bool HasTryDecode = OpcodeMask & (1 << MCD::OPC_TryDecode);
-  const bool HasCheckPredicate = OpcodeMask & (1 << MCD::OPC_CheckPredicate);
-  const bool HasSoftFail = OpcodeMask & (1 << MCD::OPC_SoftFail);
-
+                                  bool HasCheckPredicate, bool HasSoftFail,
+                                  bool HasTryDecode) {
   OS << R"(
-static unsigned decodeNumToSkip(const uint8_t *&Ptr) {
-  unsigned NumToSkip = *Ptr++;
-  NumToSkip |= (*Ptr++) << 8;
-)";
-  if (getNumToSkipInBytes() == 3)
-    OS << "  NumToSkip |= (*Ptr++) << 16;\n";
-  OS << R"(  return NumToSkip;
-}
-
 template <typename InsnType>
 static DecodeStatus decodeInstruction(const uint8_t DecodeTable[], MCInst &MI,
                                       InsnType insn, uint64_t Address,
@@ -2142,7 +2290,6 @@ static DecodeStatus decodeInstruction(const uint8_t DecodeTable[], MCInst &MI,
 
   OS << R"(
   SmallVector<const uint8_t *, 8> ScopeStack;
-  uint64_t CurFieldValue = 0;
   DecodeStatus S = MCDisassembler::Success;
   while (true) {
     ptrdiff_t Loc = Ptr - DecodeTable;
@@ -2153,51 +2300,34 @@ static DecodeStatus decodeInstruction(const uint8_t DecodeTable[], MCInst &MI,
              << (int)DecoderOp << '\n';
       return MCDisassembler::Fail;
     case MCD::OPC_Scope: {
-      unsigned NumToSkip = decodeNumToSkip(Ptr);
+      unsigned NumToSkip = decodeULEB128AndIncUnsafe(Ptr);
       const uint8_t *SkipTo = Ptr + NumToSkip;
       ScopeStack.push_back(SkipTo);
       LLVM_DEBUG(dbgs() << Loc << ": OPC_Scope(" << SkipTo - DecodeTable
                         << ")\n");
       break;
     }
-    case MCD::OPC_ExtractField: {
+    case MCD::OPC_SwitchField: {
       // Decode the start value.
       unsigned Start = decodeULEB128AndIncUnsafe(Ptr);
       unsigned Len = *Ptr++;)";
   if (IsVarLenInst)
     OS << "\n      makeUp(insn, Start + Len);";
   OS << R"(
-      CurFieldValue = fieldFromInstruction(insn, Start, Len);
-      LLVM_DEBUG(dbgs() << Loc << ": OPC_ExtractField(" << Start << ", "
-                   << Len << "): " << CurFieldValue << "\n");
-      break;
-    }
-    case MCD::OPC_FilterValueOrSkip: {
-      // Decode the field value.
-      uint64_t Val = decodeULEB128AndIncUnsafe(Ptr);
-      bool Failed = Val != CurFieldValue;
-      unsigned NumToSkip = decodeNumToSkip(Ptr);
-      const uint8_t *SkipTo = Ptr + NumToSkip;
-
-      LLVM_DEBUG(dbgs() << Loc << ": OPC_FilterValueOrSkip(" << Val << ", "
-                        << SkipTo - DecodeTable << ") "
-                        << (Failed ? "FAIL, " : "PASS\n"));
-
-      if (Failed) {
-        Ptr = SkipTo;
-        LLVM_DEBUG(dbgs() << "continuing at " << Ptr - DecodeTable << '\n');
+      uint64_t FieldVal = fieldFromInstruction(insn, Start, Len);
+      uint64_t CaseVal;
+      unsigned CaseSize;
+      while (true) {
+        CaseVal = decodeULEB128AndIncUnsafe(Ptr);
+        CaseSize = decodeULEB128AndIncUnsafe(Ptr);
+        if (FieldVal == CaseVal || !CaseSize)
+          break;
+        Ptr += CaseSize;
       }
-      break;
-    }
-    case MCD::OPC_FilterValue: {
-      // Decode the field value.
-      uint64_t Val = decodeULEB128AndIncUnsafe(Ptr);
-      bool Failed = Val != CurFieldValue;
-
-      LLVM_DEBUG(dbgs() << Loc << ": OPC_FilterValue(" << Val << ") "
-                        << (Failed ? "FAIL, " : "PASS\n"));
-
-      if (Failed) {
+      if (FieldVal == CaseVal) {
+        LLVM_DEBUG(dbgs() << Loc << ": OPC_SwitchField(" << Start << ", " << Len
+                          << "): " << FieldVal << '\n');
+      } else {
         if (ScopeStack.empty()) {
           LLVM_DEBUG(dbgs() << "returning Fail\n");
           return MCDisassembler::Fail;
@@ -2573,9 +2703,12 @@ template <typename T> constexpr uint32_t InsnBitWidth = 0;
   // Entries in `EncMap` are already sorted by bitwidth. So bucketing per
   // bitwidth can be done on-the-fly as we iterate over the map.
   DecoderTableInfo TableInfo;
-  DecoderTableBuilder TableBuilder(Target, Encodings, TableInfo);
-  unsigned OpcodeMask = 0;
+  DecoderTreeBuilder TreeBuilder(Target, Encodings, TableInfo);
+  DecoderTableEmitter TableEmitter(Target, Encodings, OS);
 
+  // Emit a table for each (namespace, hwmode, width) combination.
+  // If `SpecializeDecodersPerBitwidth` is enabled, emit a decoder function
+  // for each table.
   bool HasConflict = false;
   for (const auto &[BitWidth, BWMap] : EncMap) {
     for (const auto &[Key, EncodingIDs] : BWMap) {
@@ -2588,19 +2721,16 @@ template <typename T> constexpr uint32_t InsnBitWidth = 0;
       if (HasConflict)
         continue;
 
-      // The decode table is cleared for each top level decoder function. The
-      // predicates and decoders themselves, however, are shared across
-      // different decoders to give more opportunities for uniqueing.
-      //  - If `SpecializeDecodersPerBitwidth` is enabled, decoders are shared
-      //    across all decoder tables for a given bitwidth, else they are shared
-      //    across all decoder tables.
-      //  - predicates are shared across all decoder tables.
-      TableInfo.Table.clear();
-      TableBuilder.buildTable(FC, BitWidth);
-
-      // Print the table to the output stream.
-      OpcodeMask |= emitTable(OS, TableInfo.Table, DecoderNamespace, HwModeID,
-                              BitWidth, EncodingIDs);
+      std::unique_ptr<DecoderTreeNode> Root = TreeBuilder.buildTree(FC);
+
+      SmallString<32> TableName("DecoderTable");
+      TableName.append(DecoderNamespace);
+      if (HwModeID != DefaultMode)
+        TableName.append({"_", Target.getHwModes().getModeName(HwModeID)});
+      TableName.append(std::to_string(BitWidth));
+
+      // Serialize the tree.
+      TableEmitter.emitTable(TableName, BitWidth, Root.get());
     }
 
     // Each BitWidth get's its own decoders and decoder function if
@@ -2619,14 +2749,14 @@ template <typename T> constexpr uint32_t InsnBitWidth = 0;
   if (!SpecializeDecodersPerBitwidth)
     emitDecoderFunction(OS, TableInfo.Decoders, 0);
 
-  const bool HasCheckPredicate = OpcodeMask & (1 << MCD::OPC_CheckPredicate);
-
   // Emit the predicate function.
-  if (HasCheckPredicate)
+  if (TableEmitter.hasCheckPredicate())
     emitPredicateFunction(OS, TableInfo.Predicates);
 
   // Emit the main entry point for the decoder, decodeInstruction().
-  emitDecodeInstruction(OS, IsVarLenInst, OpcodeMask);
+  emitDecodeInstruction(OS, IsVarLenInst, TableEmitter.hasCheckPredicate(),
+                        TableEmitter.hasSoftFail(),
+                        TableEmitter.hasTryDecode());
 
   OS << "\n} // namespace\n";
 }



More information about the llvm-commits mailing list