[llvm] [TableGen][DecoderEmitter] Add option to emit type-specialized `decodeToMCInst` (PR #146593)

Rahul Joshi via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 13:50:58 PDT 2025


================
@@ -1048,90 +1064,150 @@ void DecoderEmitter::emitInstrLenTable(formatted_raw_ostream &OS,
 }
 
 void DecoderEmitter::emitPredicateFunction(formatted_raw_ostream &OS,
-                                           PredicateSet &Predicates,
-                                           indent Indent) const {
+                                           PredicateSet &Predicates) const {
   // The predicate function is just a big switch statement based on the
   // input predicate index.
-  OS << Indent << "static bool checkDecoderPredicate(unsigned Idx, "
+  OS << "static bool checkDecoderPredicate(unsigned Idx, "
      << "const FeatureBitset &Bits) {\n";
-  Indent += 2;
-  OS << Indent << "switch (Idx) {\n";
-  OS << Indent << "default: llvm_unreachable(\"Invalid index!\");\n";
+  OS << "  switch (Idx) {\n";
+  OS << "    default: llvm_unreachable(\"Invalid index!\");\n";
   for (const auto &[Index, Predicate] : enumerate(Predicates)) {
-    OS << Indent << "case " << Index << ":\n";
-    OS << Indent + 2 << "return (" << Predicate << ");\n";
+    OS << "    case " << Index << ":\n";
+    OS << "       return (" << Predicate << ");\n";
   }
-  OS << Indent << "}\n";
-  Indent -= 2;
-  OS << Indent << "}\n\n";
+  OS << "  }\n";
+  OS << "}\n\n";
+}
+
+// ----------------------------------------------------------------------------
+// CPPType implementation.
+
+CPPType::CPPType(unsigned Bitwidth, bool IsVarLenInst) : Bitwidth(Bitwidth) {
+  if (IsVarLenInst)
+    Kind = APIntTy;
+  else if (Bitwidth == 0)
+    Kind = TemplateTy;
+  else if (Bitwidth == 8 || Bitwidth == 16 || Bitwidth == 32 || Bitwidth == 64)
+    Kind = UIntTy;
+  else
+    Kind = BitsetTy;
+}
+
+std::string CPPType::getName() const {
+  switch (Kind) {
+  case TemplateTy:
+    return "InsnType";
+  case UIntTy:
+    return "uint" + std::to_string(Bitwidth) + "_t";
+  case APIntTy:
+    return "APInt";
+  case BitsetTy:
+    return "std::bitset<" + std::to_string(Bitwidth) + ">";
+  }
+  llvm_unreachable("Unexpected kind");
+}
+
+std::string CPPType::getParamDecl() const {
+  switch (Kind) {
+  case TemplateTy:
+  case BitsetTy:
+    return "const " + getName() + " &insn";
+  case UIntTy:
+    return getName() + " insn";
+  case APIntTy:
+    return "APInt &insn";
+  }
+  llvm_unreachable("Unexpected kind");
+}
+
+static void emitTemplate(formatted_raw_ostream &OS, const CPPType &Type) {
+  if (Type.Kind == CPPType::TemplateTy)
+    OS << "template <typename InsnType>\n";
 }
 
 void DecoderEmitter::emitDecoderFunction(formatted_raw_ostream &OS,
                                          DecoderSet &Decoders,
-                                         indent Indent) const {
+                                         const CPPType &Type,
+                                         StringRef Suffix) const {
   // The decoder function is just a big switch statement or a table of function
   // pointers based on the input decoder index.
+  const std::string TypeName = Type.getName();
+  const std::string TypeParamDecl = Type.getParamDecl();
 
   // TODO: When InsnType is large, using uint64_t limits all fields to 64 bits
   // It would be better for emitBinaryParser to use a 64-bit tmp whenever
   // possible but fall back to an InsnType-sized tmp for truly large fields.
-  StringRef TmpTypeDecl =
-      "using TmpType = std::conditional_t<std::is_integral<InsnType>::value, "
-      "InsnType, uint64_t>;\n";
-  StringRef DecodeParams =
-      "DecodeStatus S, InsnType insn, MCInst &MI, uint64_t Address, const "
-      "MCDisassembler *Decoder, bool &DecodeComplete";
+  auto emitTmpTypeDec = [&Type, &TypeName, &OS]() {
+    if (Type.Kind == CPPType::TemplateTy)
+      OS << formatv(
+          "  using TmpType = std::conditional_t<std::is_integral<{0}>::value, "
+          "{0}, uint64_t>;\n",
+          TypeName);
+  };
+
+  // Returns the type to use for the `tmp` variable.
+  StringRef TmpType = [&Type, &TypeName]() -> StringRef {
+    switch (Type.Kind) {
+    case CPPType::TemplateTy:
+      return "TmpType";
+    case CPPType::UIntTy:
+      return TypeName;
+    default:
+      return "uint64_t";
+    }
+  }();
+
+  auto DecodeParams =
+      formatv("DecodeStatus S, {}, MCInst &MI, uint64_t Address, "
+              "const MCDisassembler *Decoder, bool &DecodeComplete",
+              TypeParamDecl);
 
   if (UseFnTableInDecodeToMCInst) {
     // Emit a function for each case first.
     for (const auto &[Index, Decoder] : enumerate(Decoders)) {
-      OS << Indent << "template <typename InsnType>\n";
-      OS << Indent << "DecodeStatus decodeFn" << Index << "(" << DecodeParams
-         << ") {\n";
-      Indent += 2;
-      OS << Indent << TmpTypeDecl;
-      OS << Indent << "[[maybe_unused]] TmpType tmp;\n";
+      emitTemplate(OS, Type);
+      OS << "DecodeStatus decodeFn" << Suffix << '_' << Index << "("
+         << DecodeParams << ") {\n";
+      emitTmpTypeDec();
+      OS << "  [[maybe_unused]] " << TmpType << " tmp;\n";
       OS << Decoder;
-      OS << Indent << "return S;\n";
-      Indent -= 2;
-      OS << Indent << "}\n\n";
+      OS << "  return S;\n";
+      OS << "}\n\n";
     }
   }
 
-  OS << Indent << "// Handling " << Decoders.size() << " cases.\n";
-  OS << Indent << "template <typename InsnType>\n";
-  OS << Indent << "static DecodeStatus decodeToMCInst(unsigned Idx, "
+  assert(!Decoders.empty() && "Did not find any decoders");
+
+  OS << "// Handling " << Decoders.size() << " cases.\n";
+  emitTemplate(OS, Type);
+  OS << "static DecodeStatus decodeToMCInst" << Suffix << "(unsigned Idx, "
      << DecodeParams << ") {\n";
-  Indent += 2;
-  OS << Indent << "DecodeComplete = true;\n";
+  OS << "  DecodeComplete = true;\n";
 
   if (UseFnTableInDecodeToMCInst) {
     // Build a table of function pointers.
-    OS << Indent << "using DecodeFnTy = DecodeStatus (*)(" << DecodeParams
-       << ");\n";
-    OS << Indent << "static constexpr DecodeFnTy decodeFnTable[] = {\n";
+    OS << "  using DecodeFnTy = DecodeStatus (*)(" << DecodeParams << ");\n";
+    OS << "  static constexpr DecodeFnTy decodeFnTable[] = {\n";
     for (size_t Index : llvm::seq(Decoders.size()))
-      OS << Indent + 2 << "decodeFn" << Index << ",\n";
-    OS << Indent << "};\n";
-    OS << Indent << "if (Idx >= " << Decoders.size() << ")\n";
-    OS << Indent + 2 << "llvm_unreachable(\"Invalid index!\");\n";
-    OS << Indent
-       << "return decodeFnTable[Idx](S, insn, MI, Address, Decoder, "
+      OS << "    decodeFn" << Suffix << '_' << Index << ",\n";
+    OS << "  };\n";
+    OS << "  if (Idx >= " << Decoders.size() << ")\n";
+    OS << "    llvm_unreachable(\"Invalid index!\");\n";
+    OS << "return decodeFnTable[Idx](S, insn, MI, Address, Decoder, "
           "DecodeComplete);\n";
   } else {
-    OS << Indent << TmpTypeDecl;
-    OS << Indent << "TmpType tmp;\n";
-    OS << Indent << "switch (Idx) {\n";
-    OS << Indent << "default: llvm_unreachable(\"Invalid index!\");\n";
+    emitTmpTypeDec();
+    OS << "  " << TmpType << " tmp;\n";
+    OS << "  switch (Idx) {\n";
+    OS << "    default: llvm_unreachable(\"Invalid index!\");\n";
     for (const auto &[Index, Decoder] : enumerate(Decoders)) {
-      OS << Indent << "case " << Index << ":\n";
+      OS << "    case " << Index << ":\n";
----------------
jurahul wrote:

Handling this in https://github.com/llvm/llvm-project/pull/148718.

https://github.com/llvm/llvm-project/pull/146593


More information about the llvm-commits mailing list