[llvm] [TableGen][DecoderEmitter] Rework table construction/emission (PR #155889)
Rahul Joshi via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 17 11:48:35 PDT 2025
================
@@ -1718,6 +1245,619 @@ static DecodeStatus decodeInstruction(const uint8_t DecodeTable[], MCInst &MI,
)";
}
+namespace {
+
+/// Common base class of all decoder tree nodes. Every node records which
+/// concrete class it is, so consumers can switch on getKind() and
+/// static_cast to the matching subclass.
+class DecoderTreeNode {
+public:
+  /// Discriminator identifying the concrete node class.
+  enum KindTy {
+    CheckAny,
+    CheckAll,
+    CheckField,
+    SwitchField,
+    CheckPredicate,
+    SoftFail,
+    Decode,
+  };
+
+  virtual ~DecoderTreeNode() = default;
+
+  /// Returns the kind that the derived class supplied at construction.
+  KindTy getKind() const { return Kind; }
+
+protected:
+  /// Only derived classes construct nodes; each passes its own kind.
+  explicit DecoderTreeNode(KindTy Kind) : Kind(Kind) {}
+
+private:
+  KindTy Kind;
+};
+
+/// Node of kind CheckAny. Owns an ordered list of child nodes and exposes
+/// them to readers as plain `const DecoderTreeNode *` pointers.
+class CheckAnyNode : public DecoderTreeNode {
+  SmallVector<std::unique_ptr<DecoderTreeNode>, 0> Children;
+
+  /// Projection used by child_iterator: strips the owning pointer.
+  static const DecoderTreeNode *
+  mapElement(decltype(Children)::const_reference Element) {
+    return Element.get();
+  }
+
+public:
+  using child_iterator = mapped_iterator<decltype(Children)::const_iterator,
+                                         decltype(&mapElement)>;
+
+  CheckAnyNode() : DecoderTreeNode(CheckAny) {}
+
+  /// Appends \p Child to the end of the child list, taking ownership.
+  void addChild(std::unique_ptr<DecoderTreeNode> Child) {
+    Children.push_back(std::move(Child));
+  }
+
+  child_iterator child_begin() const {
+    return child_iterator(Children.begin(), &mapElement);
+  }
+
+  child_iterator child_end() const {
+    return child_iterator(Children.end(), &mapElement);
+  }
+
+  /// Range over all children in insertion order.
+  iterator_range<child_iterator> children() const {
+    return iterator_range<child_iterator>(child_begin(), child_end());
+  }
+};
+
+/// Node of kind CheckAll. Owns an ordered list of child nodes and exposes
+/// them to readers as plain `const DecoderTreeNode *` pointers.
+class CheckAllNode : public DecoderTreeNode {
+  SmallVector<std::unique_ptr<DecoderTreeNode>, 0> Children;
+
+  /// Projection used by child_iterator: strips the owning pointer.
+  static const DecoderTreeNode *
+  mapElement(decltype(Children)::const_reference Element) {
+    return Element.get();
+  }
+
+public:
+  using child_iterator = mapped_iterator<decltype(Children)::const_iterator,
+                                         decltype(&mapElement)>;
+
+  CheckAllNode() : DecoderTreeNode(CheckAll) {}
+
+  /// Appends \p Child to the end of the child list, taking ownership.
+  void addChild(std::unique_ptr<DecoderTreeNode> Child) {
+    Children.push_back(std::move(Child));
+  }
+
+  child_iterator child_begin() const {
+    return child_iterator(Children.begin(), &mapElement);
+  }
+
+  child_iterator child_end() const {
+    return child_iterator(Children.end(), &mapElement);
+  }
+
+  /// Range over all children in insertion order.
+  iterator_range<child_iterator> children() const {
+    return iterator_range<child_iterator>(child_begin(), child_end());
+  }
+};
+
+/// Leaf node of kind CheckField. Stores a bit field (start bit and width)
+/// together with the value that field is compared against.
+class CheckFieldNode : public DecoderTreeNode {
+  unsigned StartBit;
+  unsigned NumBits;
+  uint64_t Value;
+
+public:
+  CheckFieldNode(unsigned StartBit, unsigned NumBits, uint64_t Value)
+      : DecoderTreeNode(CheckField), StartBit(StartBit), NumBits(NumBits),
+        Value(Value) {}
+
+  /// Index of the first bit of the field.
+  unsigned getStartBit() const { return StartBit; }
+  /// Width of the field in bits.
+  unsigned getNumBits() const { return NumBits; }
+  /// Value the field is checked against.
+  uint64_t getValue() const { return Value; }
+};
+
+/// Node of kind SwitchField. Stores a bit field (start bit and width) and a
+/// map from field value to the owned child node for that value. Cases are
+/// iterated in ascending order of their value (std::map ordering).
+class SwitchFieldNode : public DecoderTreeNode {
+  unsigned StartBit;
+  unsigned NumBits;
+  std::map<uint64_t, std::unique_ptr<DecoderTreeNode>> Cases;
+
+  /// Projection used by case_iterator: yields (value, non-owning child).
+  static std::pair<uint64_t, const DecoderTreeNode *>
+  mapElement(decltype(Cases)::const_reference Element) {
+    return {Element.first, Element.second.get()};
+  }
+
+public:
+  using case_iterator =
+      mapped_iterator<decltype(Cases)::const_iterator, decltype(&mapElement)>;
+
+  SwitchFieldNode(unsigned StartBit, unsigned NumBits)
+      : DecoderTreeNode(SwitchField), StartBit(StartBit), NumBits(NumBits) {}
+
+  /// Registers \p N as the child for \p Value. If a case for \p Value
+  /// already exists, the existing case is kept (try_emplace semantics).
+  void addCase(uint64_t Value, std::unique_ptr<DecoderTreeNode> N) {
+    Cases.try_emplace(Value, std::move(N));
+  }
+
+  /// Index of the first bit of the switched-on field.
+  unsigned getStartBit() const { return StartBit; }
+  /// Width of the switched-on field in bits.
+  unsigned getNumBits() const { return NumBits; }
+
+  case_iterator case_begin() const {
+    return case_iterator(Cases.begin(), &mapElement);
+  }
+
+  case_iterator case_end() const {
+    return case_iterator(Cases.end(), &mapElement);
+  }
+
+  /// Range over all (value, child) pairs.
+  iterator_range<case_iterator> cases() const {
+    return iterator_range<case_iterator>(case_begin(), case_end());
+  }
+};
+
+/// Leaf node of kind CheckPredicate. Owns the text of the predicate.
+class CheckPredicateNode : public DecoderTreeNode {
+  std::string PredicateString;
+
+public:
+  /// Takes \p PredicateString by value and moves it into the node.
+  explicit CheckPredicateNode(std::string PredicateString)
+      : DecoderTreeNode(CheckPredicate),
+        PredicateString(std::move(PredicateString)) {}
+
+  /// Returns a view of the stored predicate; valid while the node is alive.
+  StringRef getPredicateString() const { return PredicateString; }
+};
+
+/// Leaf node of kind SoftFail. Stores a positive and a negative bit mask.
+class SoftFailNode : public DecoderTreeNode {
+  uint64_t PositiveMask;
+  uint64_t NegativeMask;
+
+public:
+  SoftFailNode(uint64_t PositiveMask, uint64_t NegativeMask)
+      : DecoderTreeNode(SoftFail), PositiveMask(PositiveMask),
+        NegativeMask(NegativeMask) {}
+
+  /// Returns the stored positive mask.
+  uint64_t getPositiveMask() const { return PositiveMask; }
+
+  /// Returns the stored negative mask.
+  uint64_t getNegativeMask() const { return NegativeMask; }
+};
+
+/// Leaf node of kind Decode. Refers to an instruction encoding and owns the
+/// decoder string associated with it.
+///
+/// The encoding is held by reference, so it must outlive this node.
+class DecodeNode : public DecoderTreeNode {
+  const InstructionEncoding &Encoding;
+  std::string DecoderString;
+
+public:
+  /// Binds to \p Encoding and moves \p DecoderString into the node.
+  DecodeNode(const InstructionEncoding &Encoding, std::string DecoderString)
+      : DecoderTreeNode(Decode), Encoding(Encoding),
+        DecoderString(std::move(DecoderString)) {}
+
+  /// Returns the encoding this node refers to.
+  const InstructionEncoding &getEncoding() const { return Encoding; }
+
+  /// Returns a view of the stored decoder string; valid while the node lives.
+  StringRef getDecoderString() const { return DecoderString; }
+};
+
+/// Builds a tree of DecoderTreeNode objects from a FilterChooser.
+class DecoderTreeBuilder {
+  const CodeGenTarget &Target;
+  ArrayRef<InstructionEncoding> Encodings;
+
+public:
+  /// \param Target    target whose name is used when building predicates.
+  /// \param Encodings encodings indexed by the encoding IDs passed to
+  ///                  buildTerminalNode().
+  DecoderTreeBuilder(const CodeGenTarget &Target,
+                     ArrayRef<InstructionEncoding> Encodings)
+      : Target(Target), Encodings(Encodings) {}
+
+  /// Builds the tree for \p FC; the root is produced by buildCheckAnyNode().
+  std::unique_ptr<DecoderTreeNode> buildTree(const FilterChooser &FC) {
+    return buildCheckAnyNode(FC);
+  }
+
+private:
+  /// Builds a CheckAny node for \p FC.
+  std::unique_ptr<DecoderTreeNode> buildCheckAnyNode(const FilterChooser &FC);
+
+  /// Builds either a CheckAll node (exactly one entry in \p FCMap) or a
+  /// SwitchField node over the field [StartBit, StartBit + NumBits).
+  std::unique_ptr<DecoderTreeNode> buildCheckAllOrSwitchNode(
+      unsigned StartBit, unsigned NumBits,
+      const std::map<uint64_t, std::unique_ptr<const FilterChooser>> &FCMap);
+
+  /// Builds the chain of checks that ends in a Decode node for the encoding
+  /// with index \p EncodingID.
+  std::unique_ptr<DecoderTreeNode>
+  buildTerminalNode(unsigned EncodingID, const KnownBits &FilterBits);
+};
+
+/// Emits a decoder tree as a `uint8_t` table, one entry per line, with a
+/// trailing comment on each line that shows the table index and a
+/// human-readable description of the entry.
+class DecoderTableEmitter {
+  DecoderTableInfo &TableInfo;
+  formatted_raw_ostream OS;
+  /// Field width used when printing table indices in comments.
+  unsigned IndexWidth = 0;
+  /// Number of table bytes emitted so far; advanced by emitOpcode() and
+  /// emitByte().
+  unsigned CurrentIndex = 0;
+  /// Value of CurrentIndex captured by emitStartLine(); this is the index
+  /// printed in the comment for the current line.
+  unsigned CommentIndex = 0;
+
+public:
+  /// Wraps \p OS in a formatted stream. The index members are
+  /// value-initialized so they never hold indeterminate values, even if a
+  /// helper is reached before emitTable() has set them up.
+  DecoderTableEmitter(DecoderTableInfo &TableInfo, raw_ostream &OS)
+      : TableInfo(TableInfo), OS(OS) {}
+
+  /// Emits the table named \p TableName for the tree rooted at \p Root.
+  void emitTable(StringRef TableName, unsigned BitWidth,
+                 const DecoderTreeNode *Root);
+
+private:
+  /// Registers the predicates and decoders used by the tree with TableInfo.
+  void analyzeNode(const DecoderTreeNode *Node) const;
+
+  /// Number of table bytes occupied by \p Node and its descendants.
+  unsigned computeNodeSize(const DecoderTreeNode *Node) const;
+  /// Total number of table bytes for the tree rooted at \p Root.
+  unsigned computeTableSize(const DecoderTreeNode *Root,
+                            unsigned BitWidth) const;
+
+  // Low-level helpers that write table bytes and keep CurrentIndex in sync.
+  void emitStartLine();
+  void emitOpcode(StringRef Name);
+  void emitByte(uint8_t Val);
+  void emitUInt8(unsigned Val);
+  void emitULEB128(uint64_t Val);
+  /// Starts the trailing comment of the current line and returns the stream
+  /// so the caller can append the comment text.
+  formatted_raw_ostream &emitComment(indent Indent);
+
+  // Per-kind emitters; emitNode() dispatches to these based on getKind().
+  void emitCheckAnyNode(const CheckAnyNode *N, indent Indent);
+  void emitCheckAllNode(const CheckAllNode *N, indent Indent);
+  void emitSwitchFieldNode(const SwitchFieldNode *N, indent Indent);
+  void emitCheckFieldNode(const CheckFieldNode *N, indent Indent);
+  void emitCheckPredicateNode(const CheckPredicateNode *N, indent Indent);
+  void emitSoftFailNode(const SoftFailNode *N, indent Indent);
+  void emitDecodeNode(const DecodeNode *N, indent Indent);
+  void emitNode(const DecoderTreeNode *N, indent Indent);
+};
+
+} // namespace
+
+/// Builds the CheckAll node that ends in decoding the encoding with index
+/// \p EncodingID. Children are appended in this order:
+///   1. a CheckPredicate node, if the encoding's predicate string is
+///      non-empty;
+///   2. one CheckField node per island returned by
+///      FilterChooser::getIslands(), visited in reverse order;
+///   3. a SoftFail node, if the encoding's soft-fail mask is non-zero;
+///   4. the Decode node.
+std::unique_ptr<DecoderTreeNode>
+DecoderTreeBuilder::buildTerminalNode(unsigned EncodingID,
+                                      const KnownBits &FilterBits) {
+  const InstructionEncoding &Encoding = Encodings[EncodingID];
+  auto N = std::make_unique<CheckAllNode>();
+
+  std::string Predicate = getPredicateString(Encoding, Target.getName());
+  if (!Predicate.empty())
+    N->addChild(std::make_unique<CheckPredicateNode>(std::move(Predicate)));
+
+  std::vector<FilterChooser::Island> Islands =
+      FilterChooser::getIslands(Encoding.getMandatoryBits(), FilterBits);
+  for (const FilterChooser::Island &Ilnd : reverse(Islands)) {
+    N->addChild(std::make_unique<CheckFieldNode>(Ilnd.StartBit, Ilnd.NumBits,
+                                                 Ilnd.FieldVal));
+  }
+
+  const APInt &SoftFailMask = Encoding.getSoftFailMask();
+  if (!SoftFailMask.isZero()) {
+    // Split the soft-fail bits by whether the encoding has them as 0 or 1.
+    const KnownBits &InstBits = Encoding.getInstBits();
+    APInt PositiveMask = InstBits.Zero & SoftFailMask;
+    APInt NegativeMask = InstBits.One & SoftFailMask;
+    N->addChild(std::make_unique<SoftFailNode>(PositiveMask.getZExtValue(),
+                                               NegativeMask.getZExtValue()));
+  }
+
+  // The string is not used again here, so hand it to the node instead of
+  // copying it.
+  std::string DecoderString = getDecoderString(Encoding);
+  N->addChild(
+      std::make_unique<DecodeNode>(Encoding, std::move(DecoderString)));
+
+  return N;
+}
+
+/// Builds the node that dispatches on the field [StartBit, StartBit+NumBits).
+///
+/// With exactly one entry in \p FCMap the result is a CheckAll node holding a
+/// CheckField for that single value followed by the subtree of its filter
+/// chooser. Otherwise the result is a SwitchField node with one case per
+/// entry of \p FCMap.
+std::unique_ptr<DecoderTreeNode> DecoderTreeBuilder::buildCheckAllOrSwitchNode(
+    unsigned StartBit, unsigned NumBits,
+    const std::map<uint64_t, std::unique_ptr<const FilterChooser>> &FCMap) {
+  if (FCMap.size() != 1) {
+    auto Switch = std::make_unique<SwitchFieldNode>(StartBit, NumBits);
+    for (const auto &[FieldVal, ChildFC] : FCMap)
+      Switch->addCase(FieldVal, buildCheckAnyNode(*ChildFC));
+    return Switch;
+  }
+
+  // Single value: a plain field check is enough, no switch needed.
+  const auto &[FieldVal, ChildFC] = *FCMap.begin();
+  auto All = std::make_unique<CheckAllNode>();
+  All->addChild(std::make_unique<CheckFieldNode>(StartBit, NumBits, FieldVal));
+  All->addChild(buildCheckAnyNode(*ChildFC));
+  return All;
+}
+
+/// Builds a CheckAny node for \p FC. Its first child is a terminal node when
+/// \p FC has a singleton encoding, otherwise a CheckAll/SwitchField node over
+/// FC's filter map. If \p FC has a variable filter chooser, the subtree for
+/// it is appended as a second child.
+std::unique_ptr<DecoderTreeNode>
+DecoderTreeBuilder::buildCheckAnyNode(const FilterChooser &FC) {
+  auto Any = std::make_unique<CheckAnyNode>();
+
+  std::unique_ptr<DecoderTreeNode> First =
+      FC.SingletonEncodingID
+          ? buildTerminalNode(*FC.SingletonEncodingID, FC.FilterBits)
+          : buildCheckAllOrSwitchNode(FC.StartBit, FC.NumBits,
+                                      FC.FilterChooserMap);
+  Any->addChild(std::move(First));
+
+  if (FC.VariableFC)
+    Any->addChild(buildCheckAnyNode(*FC.VariableFC));
+
+  return Any;
+}
+
+/// Recursively walks the tree rooted at \p Node and registers every predicate
+/// string and every decoder string it finds with TableInfo. Node kinds that
+/// carry neither are skipped.
+void DecoderTableEmitter::analyzeNode(const DecoderTreeNode *Node) const {
+  switch (Node->getKind()) {
+  // Leaves with nothing to register.
+  case DecoderTreeNode::CheckField:
+  case DecoderTreeNode::SoftFail:
+    return;
+
+  // Leaves that contribute an entry to TableInfo.
+  case DecoderTreeNode::CheckPredicate: {
+    const auto *Pred = static_cast<const CheckPredicateNode *>(Node);
+    TableInfo.insertPredicate(Pred->getPredicateString());
+    return;
+  }
+  case DecoderTreeNode::Decode: {
+    const auto *Dec = static_cast<const DecodeNode *>(Node);
+    TableInfo.insertDecoder(Dec->getDecoderString());
+    return;
+  }
+
+  // Inner nodes: recurse into every child.
+  case DecoderTreeNode::CheckAny: {
+    const auto *Any = static_cast<const CheckAnyNode *>(Node);
+    for (const DecoderTreeNode *Child : Any->children())
+      analyzeNode(Child);
+    return;
+  }
+  case DecoderTreeNode::CheckAll: {
+    const auto *All = static_cast<const CheckAllNode *>(Node);
+    for (const DecoderTreeNode *Child : All->children())
+      analyzeNode(Child);
+    return;
+  }
+  case DecoderTreeNode::SwitchField: {
+    const auto *Switch = static_cast<const SwitchFieldNode *>(Node);
+    for (const auto &Case : Switch->cases())
+      analyzeNode(Case.second);
+    return;
+  }
+  }
+}
+
+/// Returns the number of table bytes needed for \p Node and all of its
+/// descendants. The per-kind formulas mirror the bytes written by the
+/// corresponding emit*Node() functions: one byte per opcode or fixed-width
+/// field plus the ULEB128-encoded size of every variable-width operand.
+///
+/// CheckAny and SwitchField nodes must have at least one child/case, because
+/// the last one is encoded differently from the others and is accessed via
+/// std::prev() on the end iterator.
+unsigned
+DecoderTableEmitter::computeNodeSize(const DecoderTreeNode *Node) const {
+  switch (Node->getKind()) {
+  case DecoderTreeNode::CheckAny: {
+    const auto *N = static_cast<const CheckAnyNode *>(Node);
+    assert(!N->children().empty() &&
+           "CheckAny node must have at least one child");
+    unsigned Size = 0;
+    // Every child but the last is preceded by an opcode byte and the
+    // ULEB128-encoded size of the child.
+    for (const DecoderTreeNode *Child : drop_end(N->children())) {
+      unsigned ChildSize = computeNodeSize(Child);
+      Size += 1 + getULEB128Size(ChildSize) + ChildSize;
+    }
+    // The last child is emitted without any prefix.
+    return Size + computeNodeSize(*std::prev(N->child_end()));
+  }
+  case DecoderTreeNode::CheckAll: {
+    const auto *N = static_cast<const CheckAllNode *>(Node);
+    unsigned Size = 0;
+    for (const DecoderTreeNode *Child : N->children())
+      Size += computeNodeSize(Child);
+    return Size;
+  }
+  case DecoderTreeNode::CheckField: {
+    const auto *N = static_cast<const CheckFieldNode *>(Node);
+    // Opcode byte, ULEB128 start bit, one byte for the width, ULEB128 value.
+    return 1 + getULEB128Size(N->getStartBit()) + 1 +
+           getULEB128Size(N->getValue());
+  }
+  case DecoderTreeNode::SwitchField: {
+    const auto *N = static_cast<const SwitchFieldNode *>(Node);
+    assert(!N->cases().empty() &&
+           "SwitchField node must have at least one case");
+    // Opcode byte, ULEB128 start bit, one byte for the width.
+    unsigned Size = 1 + getULEB128Size(N->getStartBit()) + 1;
+
+    // Every case but the last: ULEB128 value, ULEB128 child size, child.
+    for (auto [Val, Child] : drop_end(N->cases())) {
+      unsigned ChildSize = computeNodeSize(Child);
+      Size += getULEB128Size(Val) + getULEB128Size(ChildSize) + ChildSize;
+    }
+
+    // The last case stores 0 in place of the child size.
+    auto [Val, Child] = *std::prev(N->case_end());
+    unsigned ChildSize = computeNodeSize(Child);
+    Size += getULEB128Size(Val) + getULEB128Size(0) + ChildSize;
+    return Size;
+  }
+  case DecoderTreeNode::CheckPredicate: {
+    const auto *N = static_cast<const CheckPredicateNode *>(Node);
+    unsigned PredicateIndex =
+        TableInfo.getPredicateIndex(N->getPredicateString());
+    return 1 + getULEB128Size(PredicateIndex);
+  }
+  case DecoderTreeNode::SoftFail: {
+    const auto *N = static_cast<const SoftFailNode *>(Node);
+    return 1 + getULEB128Size(N->getPositiveMask()) +
+           getULEB128Size(N->getNegativeMask());
+  }
+  case DecoderTreeNode::Decode: {
+    const auto *N = static_cast<const DecodeNode *>(Node);
+    unsigned InstOpcode = N->getEncoding().getInstruction()->EnumVal;
+    unsigned DecoderIndex = TableInfo.getDecoderIndex(N->getDecoderString());
+    return 1 + getULEB128Size(InstOpcode) + getULEB128Size(DecoderIndex);
+  }
+  }
+  llvm_unreachable("Unknown node kind");
+}
+
+/// Returns the total size in bytes of the table for the tree rooted at
+/// \p Root. When SpecializeDecodersPerBitwidth is set, the size also includes
+/// the ULEB128 encoding of \p BitWidth.
+unsigned DecoderTableEmitter::computeTableSize(const DecoderTreeNode *Root,
+                                               unsigned BitWidth) const {
+  const unsigned HeaderSize =
+      SpecializeDecodersPerBitwidth ? getULEB128Size(BitWidth) : 0;
+  return HeaderSize + computeNodeSize(Root);
+}
+
+/// Begins a new table line: writes the two-space indentation and remembers
+/// the current table index so emitComment() can print it for this line.
+void DecoderTableEmitter::emitStartLine() {
+  OS.indent(2);
+  CommentIndex = CurrentIndex;
+}
+
+/// Starts a new line and writes the opcode \p Name followed by ", ".
+/// An opcode occupies one table byte, so the index advances by one.
+void DecoderTableEmitter::emitOpcode(StringRef Name) {
+  // Must run before the index is advanced: it records the index of this line.
+  emitStartLine();
+  OS << Name;
+  OS << ", ";
+  CurrentIndex += 1;
+}
+
+/// Writes \p Val as a decimal number followed by ", " and advances the table
+/// index by one.
+void DecoderTableEmitter::emitByte(uint8_t Val) {
+  // Widen first so the byte is printed as a number, not as a character.
+  const unsigned Number = Val;
+  OS << Number << ", ";
+  CurrentIndex += 1;
+}
+
+/// Emits \p Val as a single table byte. \p Val must fit into 8 bits; this is
+/// checked with an assertion.
+void DecoderTableEmitter::emitUInt8(unsigned Val) {
+  assert(isUInt<8>(Val));
+  emitByte(static_cast<uint8_t>(Val));
+}
+
+/// Emits \p Val in ULEB128 form, one table byte per 7-bit group, least
+/// significant group first. Every byte except the last has bit 7 set.
+void DecoderTableEmitter::emitULEB128(uint64_t Val) {
+  for (; Val >= 0x80; Val >>= 7)
+    emitByte(static_cast<uint8_t>((Val & 0x7F) | 0x80));
+  // Final group: fits in 7 bits, continuation bit clear.
+  emitByte(static_cast<uint8_t>(Val));
+}
+
+/// Starts the trailing comment of the current line. The comment begins at a
+/// fixed column; if the line already extends past that column, the comment
+/// is placed on a fresh line instead. Writes "// <index>: " followed by
+/// \p Indent and returns the stream so the caller can append the text.
+formatted_raw_ostream &DecoderTableEmitter::emitComment(indent Indent) {
+  constexpr unsigned CommentColumn = 45;
+
+  const bool LineTooLong = OS.getColumn() > CommentColumn;
+  if (LineTooLong)
+    OS << '\n';
+  OS.PadToColumn(CommentColumn);
+
+  OS << "// ";
+  OS << format_decimal(CommentIndex, IndexWidth);
+  OS << ": " << Indent;
+  return OS;
+}
+
+/// Emits a CheckAny node. Every child except the last is wrapped in an
+/// OPC_Scope entry that carries the child's size and is rendered one
+/// indentation level deeper between "{" and "}" comments. The last child is
+/// emitted directly at the current indentation.
+void DecoderTableEmitter::emitCheckAnyNode(const CheckAnyNode *N,
+                                           indent Indent) {
+  const auto Last = std::prev(N->child_end());
+
+  for (auto It = N->child_begin(); It != Last; ++It) {
+    const DecoderTreeNode *Scoped = *It;
+
+    emitOpcode("OPC_Scope");
+    emitULEB128(computeNodeSize(Scoped));
+    emitComment(Indent) << "{\n";
+
+    emitNode(Scoped, Indent + 1);
+
+    emitComment(Indent) << "}\n";
+  }
+
+  emitNode(*Last, Indent);
+}
+
+/// Emits a CheckAll node: its children are emitted one after another, in
+/// order, at the same indentation level and without any framing entry.
+void DecoderTableEmitter::emitCheckAllNode(const CheckAllNode *N,
+                                           indent Indent) {
+  for (auto It = N->child_begin(), End = N->child_end(); It != End; ++It)
+    emitNode(*It, Indent);
+}
+
+/// Emits a SwitchField node: an OPC_SwitchField entry with the field's start
+/// bit and width, followed by one entry per case. Each case entry holds the
+/// case value and the size of the case body; for the last case the size is
+/// written as 0. Case bodies are rendered one indentation level deeper.
+void DecoderTableEmitter::emitSwitchFieldNode(const SwitchFieldNode *N,
+                                              indent Indent) {
+  const unsigned LSB = N->getStartBit();
+  const unsigned Width = N->getNumBits();
+  const unsigned MSB = LSB + Width - 1;
+
+  emitOpcode("OPC_SwitchField");
+  emitULEB128(LSB);
+  emitUInt8(Width);
+  emitComment(Indent) << "switch Inst[" << MSB << ':' << LSB << "] {\n";
+
+  const auto LastCase = std::prev(N->case_end());
+  for (auto It = N->case_begin(), End = N->case_end(); It != End; ++It) {
+    auto [Val, Child] = *It;
+
+    emitStartLine();
+    emitULEB128(Val);
+    // The last case has no successor to skip to, so its size slot is 0.
+    emitULEB128(It == LastCase ? 0 : computeNodeSize(Child));
+    emitComment(Indent) << "case " << format_hex(Val, 0) << ": {\n";
+
+    emitNode(Child, Indent + 1);
+
+    emitComment(Indent) << "}\n";
+  }
+
+  emitComment(Indent) << "} // switch Inst[" << MSB << ':' << LSB << "]\n";
+}
+
+/// Emits a CheckField node: OPC_CheckField, the start bit (ULEB128), the
+/// width (one byte) and the expected value (ULEB128), followed by a comment
+/// showing the checked bit range and value.
+void DecoderTableEmitter::emitCheckFieldNode(const CheckFieldNode *N,
+                                             indent Indent) {
+  const unsigned LSB = N->getStartBit();
+  const unsigned Width = N->getNumBits();
+  const uint64_t Val = N->getValue();
+  const unsigned MSB = LSB + Width - 1;
+
+  emitOpcode("OPC_CheckField");
+  emitULEB128(LSB);
+  emitUInt8(Width);
+  emitULEB128(Val);
+
+  emitComment(Indent) << "check Inst[" << MSB << ':' << LSB
+                      << "] == " << format_hex(Val, 0) << '\n';
+}
+
+/// Emits a CheckPredicate node: OPC_CheckPredicate followed by the index that
+/// TableInfo reports for the node's predicate string. Also records in
+/// TableInfo that the table uses a predicate check.
+void DecoderTableEmitter::emitCheckPredicateNode(const CheckPredicateNode *N,
+                                                 indent Indent) {
+  TableInfo.HasCheckPredicate = true;
+
+  const unsigned PredicateIndex =
+      TableInfo.getPredicateIndex(N->getPredicateString());
+  emitOpcode("OPC_CheckPredicate");
+  emitULEB128(PredicateIndex);
+
+  emitComment(Indent) << "check predicate " << PredicateIndex << "\n";
+}
+
+/// Emits a SoftFail node: OPC_SoftFail followed by the positive and negative
+/// masks (both ULEB128). Also records in TableInfo that the table uses a
+/// soft-fail check.
+void DecoderTableEmitter::emitSoftFailNode(const SoftFailNode *N,
+                                           indent Indent) {
+  TableInfo.HasSoftFail = true;
+
+  const uint64_t PositiveMask = N->getPositiveMask();
+  const uint64_t NegativeMask = N->getNegativeMask();
+  emitOpcode("OPC_SoftFail");
+  emitULEB128(PositiveMask);
+  emitULEB128(NegativeMask);
+
+  emitComment(Indent) << "check softfail"
+                      << " pos=" << format_hex(PositiveMask, 10)
+                      << " neg=" << format_hex(NegativeMask, 10) << '\n';
+}
+
+/// Emits a Decode node: OPC_Decode followed by the instruction's enum value
+/// and the index that TableInfo reports for the node's decoder string (both
+/// ULEB128). The comment names the encoding and the decoder index.
+void DecoderTableEmitter::emitDecodeNode(const DecodeNode *N, indent Indent) {
+  const InstructionEncoding &Enc = N->getEncoding();
+  const unsigned Opcode = Enc.getInstruction()->EnumVal;
+  const unsigned DecoderIdx = TableInfo.getDecoderIndex(N->getDecoderString());
+
+  emitOpcode("OPC_Decode");
+  emitULEB128(Opcode);
+  emitULEB128(DecoderIdx);
+
+  formatted_raw_ostream &Comment = emitComment(Indent);
+  Comment << "decode to " << Enc.getName() << " using decoder " << DecoderIdx
+          << '\n';
+}
+
+/// Dispatches to the emitter that matches the kind of \p N, downcasting the
+/// node to its concrete class first.
+void DecoderTableEmitter::emitNode(const DecoderTreeNode *N, indent Indent) {
+  switch (N->getKind()) {
+  case DecoderTreeNode::CheckAny:
+    emitCheckAnyNode(static_cast<const CheckAnyNode *>(N), Indent);
+    return;
+  case DecoderTreeNode::CheckAll:
+    emitCheckAllNode(static_cast<const CheckAllNode *>(N), Indent);
+    return;
+  case DecoderTreeNode::SwitchField:
+    emitSwitchFieldNode(static_cast<const SwitchFieldNode *>(N), Indent);
+    return;
+  case DecoderTreeNode::CheckField:
+    emitCheckFieldNode(static_cast<const CheckFieldNode *>(N), Indent);
+    return;
+  case DecoderTreeNode::CheckPredicate:
+    emitCheckPredicateNode(static_cast<const CheckPredicateNode *>(N), Indent);
+    return;
+  case DecoderTreeNode::SoftFail:
+    emitSoftFailNode(static_cast<const SoftFailNode *>(N), Indent);
+    return;
+  case DecoderTreeNode::Decode:
+    emitDecodeNode(static_cast<const DecodeNode *>(N), Indent);
+    return;
+  }
+  llvm_unreachable("Unknown node kind");
+}
+
+void DecoderTableEmitter::emitTable(StringRef TableName, unsigned BitWidth,
+ const DecoderTreeNode *Root) {
+ analyzeNode(Root);
+
+ unsigned TableSize = computeTableSize(Root, BitWidth);
+ OS << "static const uint8_t " << TableName << "[" << TableSize << "] = {\n";
+
+ // Calculate the number of decimal places for table indices.
+ // This is simply log10 of the table size.
+ IndexWidth = 1;
----------------
jurahul wrote:
nit: llvm/include/llvm/DebugInfo/PDB/Native/FormatUtil.h already has
```
/// Returns the number of digits in the given integer.
inline int NumDigits(uint64_t N) {
```
Maybe we can extract it into MathExtras.h or Format.h at some future point and reuse it here.
https://github.com/llvm/llvm-project/pull/155889
More information about the llvm-commits
mailing list