[llvm] 31f9519 - [TableGen][CodeEmitter] Introducing the VarLenCodeEmitterGen infrastructure

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 11 09:32:54 PST 2022


Author: Min-Yih Hsu
Date: 2022-02-11T09:31:11-08:00
New Revision: 31f9519d48c26bd542253cc20dc333732c991332

URL: https://github.com/llvm/llvm-project/commit/31f9519d48c26bd542253cc20dc333732c991332
DIFF: https://github.com/llvm/llvm-project/commit/31f9519d48c26bd542253cc20dc333732c991332.diff

LOG: [TableGen][CodeEmitter] Introducing the VarLenCodeEmitterGen infrastructure

Full write up:
https://gist.github.com/mshockwave/66e98d099256deefc062633909bb7b5b

The existing CodeEmitterGen infrastructure is unable to generate encoder
functions for ISAs with variable-length instructions. This patch
introduces a new infrastructure to support variable-length instruction
encoding, including a new TableGen syntax for writing instruction
encoding directives and a new TableGen backend component,
VarLenCodeEmitterGen, built on top of CodeEmitterGen.

Differential Revision: https://reviews.llvm.org/D115128

Added: 
    llvm/test/TableGen/VarLenEncoder.td
    llvm/utils/TableGen/VarLenCodeEmitterGen.cpp
    llvm/utils/TableGen/VarLenCodeEmitterGen.h

Modified: 
    llvm/include/llvm/Target/Target.td
    llvm/utils/TableGen/CMakeLists.txt
    llvm/utils/TableGen/CodeEmitterGen.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td
index 44ae273f39103..8e93a6e6fd6ab 100644
--- a/llvm/include/llvm/Target/Target.td
+++ b/llvm/include/llvm/Target/Target.td
@@ -756,6 +756,28 @@ def ins;
 /// of operands.
 def variable_ops;
 
+/// Variable-length instruction encoding utilities.
+/// The `ascend` operator should be used like this:
+///     (ascend 0b0000, 0b1111)
+/// which represents a sequence of encoding fragments placed from LSB to MSB.
+/// Thus, in this case the final encoding will be 0b11110000.
+/// The arguments for `ascend` can be `bits`, a single `bit`, or another DAG.
+def ascend;
+/// In addition, we can use `descend` to describe an encoding that places
+/// its arguments (i.e. encoding fragments) from MSB to LSB. For instance:
+///     (descend 0b0000, 0b1111)
+/// This results in an encoding of 0b00001111.
+def descend;
+/// The `operand` operator should be used like this:
+///     (operand "$src", 4)
+/// which represents a 4-bit encoding for an instruction operand named `$src`.
+def operand;
+/// Similar to `operand`, we can reference only part of the operand's encoding:
+///     (slice "$src", 6, 8)
+///     (slice "$src", 8, 6)
+/// Both DAGs represent bits 6 to 8 (3 bits in total) in the encoding of
+/// operand `$src`.
+def slice;
 
 /// PointerLikeRegClass - Values that are designed to have pointer width are
 /// derived from this.  TableGen treats the register class as having a symbolic

diff  --git a/llvm/test/TableGen/VarLenEncoder.td b/llvm/test/TableGen/VarLenEncoder.td
new file mode 100644
index 0000000000000..a1ea389ffad23
--- /dev/null
+++ b/llvm/test/TableGen/VarLenEncoder.td
@@ -0,0 +1,93 @@
+// RUN: llvm-tblgen -gen-emitter -I %p/../../include %s | FileCheck %s
+
+// Check if VarLenCodeEmitterGen works correctly.
+
+include "llvm/Target/Target.td"
+
+def ArchInstrInfo : InstrInfo { }
+
+def Arch : Target {
+  let InstructionSet = ArchInstrInfo;
+}
+
+def Reg : Register<"reg">;
+
+def RegClass : RegisterClass<"foo", [i64], 0, (add Reg)>;
+
+def GR64 : RegisterOperand<RegClass>;
+
+class MyMemOperand<dag sub_ops> : Operand<iPTR> {
+  let MIOperandInfo = sub_ops;
+  dag Base;
+  dag Extension;
+}
+
+class MyVarInst<MyMemOperand memory_op> : Instruction {
+  dag Inst;
+
+  let OutOperandList = (outs GR64:$dst);
+  let InOperandList  = (ins memory_op:$src);
+
+  // Testing `ascend` and `descend`: the `descend` group fills bits 0-15.
+  let Inst = (ascend
+    (descend 0b10110111, memory_op.Base),
+    memory_op.Extension,
+    // Testing operand referencing (the 4 low bits of $dst).
+    (operand "$dst", 4),
+    // Testing operand referencing with a bit range: bits 1-3 of $dst.
+    (slice "$dst", 3, 1)
+  );
+}
+
+class MemOp16<string op_name> : MyMemOperand<(ops GR64:$reg, i16imm:$offset)> {
+  // Testing sub-operand referencing (e.g. "$src.reg", "$src.offset").
+  let Base = (operand "$"#op_name#".reg", 8);
+  let Extension = (operand "$"#op_name#".offset", 16);
+}
+
+class MemOp32<string op_name> : MyMemOperand<(ops GR64:$reg, i32imm:$offset)> {
+  let Base = (operand "$"#op_name#".reg", 8);
+  // Testing variable-length encoding: a 32-bit offset makes FOO32 longer.
+  let Extension = (operand "$"#op_name#".offset", 32);
+}
+
+def FOO16 : MyVarInst<MemOp16<"src">>;
+def FOO32 : MyVarInst<MemOp32<"src">>;
+
+// The fixed-bits part: per-opcode {NumBits, Index} entries and base words.
+// CHECK: {/*NumBits*/39,
+// CHECK-SAME: // FOO16
+// CHECK: {/*NumBits*/55,
+// CHECK-SAME: // FOO32
+// CHECK: UINT64_C(46848), // FOO16
+// CHECK: UINT64_C(46848), // FOO32
+
+// CHECK-LABEL: case ::FOO16: {
+// CHECK: Scratch = Scratch.zextOrSelf(39);
+// src.reg
+// CHECK: getMachineOpValue(MI, MI.getOperand(1), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(8, 0), 0);
+// src.offset
+// CHECK: getMachineOpValue(MI, MI.getOperand(2), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(16, 0), 16);
+// 1st dst
+// CHECK: getMachineOpValue(MI, MI.getOperand(0), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(4, 0), 32);
+// 2nd dst
+// CHECK: getMachineOpValue(MI, MI.getOperand(0), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(3, 1), 36);
+
+// CHECK-LABEL: case ::FOO32: {
+// CHECK: Scratch = Scratch.zextOrSelf(55);
+// src.reg
+// CHECK: getMachineOpValue(MI, MI.getOperand(1), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(8, 0), 0);
+// src.offset
+// CHECK: getMachineOpValue(MI, MI.getOperand(2), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(32, 0), 16);
+// 1st dst
+// CHECK: getMachineOpValue(MI, MI.getOperand(0), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(4, 0), 48);
+// 2nd dst
+// CHECK: getMachineOpValue(MI, MI.getOperand(0), Scratch, Fixups, STI);
+// CHECK: Inst.insertBits(Scratch.extractBits(3, 1), 52);

diff  --git a/llvm/utils/TableGen/CMakeLists.txt b/llvm/utils/TableGen/CMakeLists.txt
index 97df6a55d1b59..339692bcd6512 100644
--- a/llvm/utils/TableGen/CMakeLists.txt
+++ b/llvm/utils/TableGen/CMakeLists.txt
@@ -49,6 +49,7 @@ add_tablegen(llvm-tblgen LLVM
   SubtargetFeatureInfo.cpp
   TableGen.cpp
   Types.cpp
+  VarLenCodeEmitterGen.cpp
   X86DisassemblerTables.cpp
   X86EVEX2VEXTablesEmitter.cpp
   X86FoldTablesEmitter.cpp

diff  --git a/llvm/utils/TableGen/CodeEmitterGen.cpp b/llvm/utils/TableGen/CodeEmitterGen.cpp
index fbac0d969917c..f446e5fe44148 100644
--- a/llvm/utils/TableGen/CodeEmitterGen.cpp
+++ b/llvm/utils/TableGen/CodeEmitterGen.cpp
@@ -16,6 +16,7 @@
 #include "CodeGenTarget.h"
 #include "SubtargetFeatureInfo.h"
 #include "Types.h"
+#include "VarLenCodeEmitterGen.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringExtras.h"
@@ -396,132 +397,140 @@ void CodeEmitterGen::run(raw_ostream &o) {
   ArrayRef<const CodeGenInstruction*> NumberedInstructions =
     Target.getInstructionsByEnumValue();
 
-  const CodeGenHwModes &HWM = Target.getHwModes();
-  // The set of HwModes used by instruction encodings.
-  std::set<unsigned> HwModes;
-  BitWidth = 0;
-  for (const CodeGenInstruction *CGI : NumberedInstructions) {
-    Record *R = CGI->TheDef;
-    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
-        R->getValueAsBit("isPseudo"))
-      continue;
+  if (any_of(NumberedInstructions, [](const CodeGenInstruction *CGI) {
+        Record *R = CGI->TheDef;
+        return R->getValue("Inst") && isa<DagInit>(R->getValueInit("Inst"));
+      })) {
+    emitVarLenCodeEmitter(Records, o);
+  } else {
+    const CodeGenHwModes &HWM = Target.getHwModes();
+    // The set of HwModes used by instruction encodings.
+    std::set<unsigned> HwModes;
+    BitWidth = 0;
+    for (const CodeGenInstruction *CGI : NumberedInstructions) {
+      Record *R = CGI->TheDef;
+      if (R->getValueAsString("Namespace") == "TargetOpcode" ||
+          R->getValueAsBit("isPseudo"))
+        continue;
 
-    if (const RecordVal *RV = R->getValue("EncodingInfos")) {
-      if (DefInit *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
-        EncodingInfoByHwMode EBM(DI->getDef(), HWM);
-        for (auto &KV : EBM) {
-          BitsInit *BI = KV.second->getValueAsBitsInit("Inst");
-          BitWidth = std::max(BitWidth, BI->getNumBits());
-          HwModes.insert(KV.first);
+      if (const RecordVal *RV = R->getValue("EncodingInfos")) {
+        if (DefInit *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
+          EncodingInfoByHwMode EBM(DI->getDef(), HWM);
+          for (auto &KV : EBM) {
+            BitsInit *BI = KV.second->getValueAsBitsInit("Inst");
+            BitWidth = std::max(BitWidth, BI->getNumBits());
+            HwModes.insert(KV.first);
+          }
+          continue;
         }
-        continue;
       }
+      BitsInit *BI = R->getValueAsBitsInit("Inst");
+      BitWidth = std::max(BitWidth, BI->getNumBits());
     }
-    BitsInit *BI = R->getValueAsBitsInit("Inst");
-    BitWidth = std::max(BitWidth, BI->getNumBits());
-  }
-  UseAPInt = BitWidth > 64;
-  
-  // Emit function declaration
-  if (UseAPInt) {
-    o << "void " << Target.getName()
-      << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
-      << "    SmallVectorImpl<MCFixup> &Fixups,\n"
-      << "    APInt &Inst,\n"
-      << "    APInt &Scratch,\n"
-      << "    const MCSubtargetInfo &STI) const {\n";
-  } else {
-    o << "uint64_t " << Target.getName();
-    o << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
-      << "    SmallVectorImpl<MCFixup> &Fixups,\n"
-      << "    const MCSubtargetInfo &STI) const {\n";
-  }
-  
-  // Emit instruction base values
-  if (HwModes.empty()) {
-    emitInstructionBaseValues(o, NumberedInstructions, Target, -1);
-  } else {
-    for (unsigned HwMode : HwModes)
-      emitInstructionBaseValues(o, NumberedInstructions, Target, (int)HwMode);
-  }
+    UseAPInt = BitWidth > 64;
 
-  if (!HwModes.empty()) {
-    o << "  const uint64_t *InstBits;\n";
-    o << "  unsigned HwMode = STI.getHwMode();\n";
-    o << "  switch (HwMode) {\n";
-    o << "  default: llvm_unreachable(\"Unknown hardware mode!\"); break;\n";
-    for (unsigned I : HwModes) {
-      o << "  case " << I << ": InstBits = InstBits_" << HWM.getMode(I).Name
-        << "; break;\n";
+    // Emit function declaration
+    if (UseAPInt) {
+      o << "void " << Target.getName()
+        << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
+        << "    SmallVectorImpl<MCFixup> &Fixups,\n"
+        << "    APInt &Inst,\n"
+        << "    APInt &Scratch,\n"
+        << "    const MCSubtargetInfo &STI) const {\n";
+    } else {
+      o << "uint64_t " << Target.getName();
+      o << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
+        << "    SmallVectorImpl<MCFixup> &Fixups,\n"
+        << "    const MCSubtargetInfo &STI) const {\n";
     }
-    o << "  };\n";
-  }
 
-  // Map to accumulate all the cases.
-  std::map<std::string, std::vector<std::string>> CaseMap;
+    // Emit instruction base values
+    if (HwModes.empty()) {
+      emitInstructionBaseValues(o, NumberedInstructions, Target, -1);
+    } else {
+      for (unsigned HwMode : HwModes)
+        emitInstructionBaseValues(o, NumberedInstructions, Target, (int)HwMode);
+    }
 
-  // Construct all cases statement for each opcode
-  for (Record *R : Insts) {
-    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
-        R->getValueAsBit("isPseudo"))
-      continue;
-    std::string InstName =
-        (R->getValueAsString("Namespace") + "::" + R->getName()).str();
-    std::string Case = getInstructionCase(R, Target);
+    if (!HwModes.empty()) {
+      o << "  const uint64_t *InstBits;\n";
+      o << "  unsigned HwMode = STI.getHwMode();\n";
+      o << "  switch (HwMode) {\n";
+      o << "  default: llvm_unreachable(\"Unknown hardware mode!\"); break;\n";
+      for (unsigned I : HwModes) {
+        o << "  case " << I << ": InstBits = InstBits_" << HWM.getMode(I).Name
+          << "; break;\n";
+      }
+      o << "  };\n";
+    }
 
-    CaseMap[Case].push_back(std::move(InstName));
-  }
+    // Map to accumulate all the cases.
+    std::map<std::string, std::vector<std::string>> CaseMap;
 
-  // Emit initial function code
-  if (UseAPInt) {
-    int NumWords = APInt::getNumWords(BitWidth);
-    int NumBytes = (BitWidth + 7) / 8;
-    o << "  const unsigned opcode = MI.getOpcode();\n"
-      << "  if (Inst.getBitWidth() != " << BitWidth << ")\n"
-      << "    Inst = Inst.zext(" << BitWidth << ");\n"
-      << "  if (Scratch.getBitWidth() != " << BitWidth << ")\n"
-      << "    Scratch = Scratch.zext(" << BitWidth << ");\n"
-      << "  LoadIntFromMemory(Inst, (const uint8_t *)&InstBits[opcode * "
-      << NumWords << "], " << NumBytes << ");\n"
-      << "  APInt &Value = Inst;\n"
-      << "  APInt &op = Scratch;\n"
-      << "  switch (opcode) {\n";
-  } else {
-    o << "  const unsigned opcode = MI.getOpcode();\n"
-      << "  uint64_t Value = InstBits[opcode];\n"
-      << "  uint64_t op = 0;\n"
-      << "  (void)op;  // suppress warning\n"
-      << "  switch (opcode) {\n";
-  }
+    // Construct all cases statement for each opcode
+    for (Record *R : Insts) {
+      if (R->getValueAsString("Namespace") == "TargetOpcode" ||
+          R->getValueAsBit("isPseudo"))
+        continue;
+      std::string InstName =
+          (R->getValueAsString("Namespace") + "::" + R->getName()).str();
+      std::string Case = getInstructionCase(R, Target);
+
+      CaseMap[Case].push_back(std::move(InstName));
+    }
+
+    // Emit initial function code
+    if (UseAPInt) {
+      int NumWords = APInt::getNumWords(BitWidth);
+      int NumBytes = (BitWidth + 7) / 8;
+      o << "  const unsigned opcode = MI.getOpcode();\n"
+        << "  if (Inst.getBitWidth() != " << BitWidth << ")\n"
+        << "    Inst = Inst.zext(" << BitWidth << ");\n"
+        << "  if (Scratch.getBitWidth() != " << BitWidth << ")\n"
+        << "    Scratch = Scratch.zext(" << BitWidth << ");\n"
+        << "  LoadIntFromMemory(Inst, (const uint8_t *)&InstBits[opcode * "
+        << NumWords << "], " << NumBytes << ");\n"
+        << "  APInt &Value = Inst;\n"
+        << "  APInt &op = Scratch;\n"
+        << "  switch (opcode) {\n";
+    } else {
+      o << "  const unsigned opcode = MI.getOpcode();\n"
+        << "  uint64_t Value = InstBits[opcode];\n"
+        << "  uint64_t op = 0;\n"
+        << "  (void)op;  // suppress warning\n"
+        << "  switch (opcode) {\n";
+    }
 
-  // Emit each case statement
-  std::map<std::string, std::vector<std::string>>::iterator IE, EE;
-  for (IE = CaseMap.begin(), EE = CaseMap.end(); IE != EE; ++IE) {
-    const std::string &Case = IE->first;
-    std::vector<std::string> &InstList = IE->second;
+    // Emit each case statement
+    std::map<std::string, std::vector<std::string>>::iterator IE, EE;
+    for (IE = CaseMap.begin(), EE = CaseMap.end(); IE != EE; ++IE) {
+      const std::string &Case = IE->first;
+      std::vector<std::string> &InstList = IE->second;
 
-    for (int i = 0, N = InstList.size(); i < N; i++) {
-      if (i) o << "\n";
-      o << "    case " << InstList[i]  << ":";
+      for (int i = 0, N = InstList.size(); i < N; i++) {
+        if (i)
+          o << "\n";
+        o << "    case " << InstList[i] << ":";
+      }
+      o << " {\n";
+      o << Case;
+      o << "      break;\n"
+        << "    }\n";
     }
-    o << " {\n";
-    o << Case;
-    o << "      break;\n"
-      << "    }\n";
-  }
 
-  // Default case: unhandled opcode
-  o << "  default:\n"
-    << "    std::string msg;\n"
-    << "    raw_string_ostream Msg(msg);\n"
-    << "    Msg << \"Not supported instr: \" << MI;\n"
-    << "    report_fatal_error(msg.c_str());\n"
-    << "  }\n";
-  if (UseAPInt)
-    o << "  Inst = Value;\n";
-  else
-    o << "  return Value;\n";
-  o << "}\n\n";
+    // Default case: unhandled opcode
+    o << "  default:\n"
+      << "    std::string msg;\n"
+      << "    raw_string_ostream Msg(msg);\n"
+      << "    Msg << \"Not supported instr: \" << MI;\n"
+      << "    report_fatal_error(Msg.str().c_str());\n"
+      << "  }\n";
+    if (UseAPInt)
+      o << "  Inst = Value;\n";
+    else
+      o << "  return Value;\n";
+    o << "}\n\n";
+  }
 
   const auto &All = SubtargetFeatureInfo::getAll(Records);
   std::map<Record *, SubtargetFeatureInfo, LessRecordByID> SubtargetFeatures;

diff  --git a/llvm/utils/TableGen/VarLenCodeEmitterGen.cpp b/llvm/utils/TableGen/VarLenCodeEmitterGen.cpp
new file mode 100644
index 0000000000000..832c9053ffb9b
--- /dev/null
+++ b/llvm/utils/TableGen/VarLenCodeEmitterGen.cpp
@@ -0,0 +1,491 @@
+//===- VarLenCodeEmitterGen.cpp - CEG for variable-length insts -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The CodeEmitterGen component for variable-length instructions.
+//
+// The basic CodeEmitterGen is almost exclusively designed for fixed-
+// length instructions. A good analogy for its encoding scheme is how printf
+// works: the (immutable) format string represents the fixed values in the
+// encoded instruction. Placeholders (i.e. %something), on the other hand,
+// represent encoding for instruction operands.
+// ```
+// printf("1101 %src 1001 %dst", <encoded value for operand `src`>,
+//                               <encoded value for operand `dst`>);
+// ```
+// VarLenCodeEmitterGen in this file provides an alternative encoding scheme
+// that works more like a C++ stream operator:
+// ```
+// OS << 0b1101;
+// if (Cond)
+//   OS << OperandEncoding0;
+// OS << 0b1001 << OperandEncoding1;
+// ```
+// You are free to concatenate encoding fragments of arbitrary types (and
+// sizes) at any bit position, which gives more flexibility in defining
+// encodings for variable-length instructions.
+//
+// In a more specific way, instruction encoding is represented by a DAG type
+// `Inst` field. Here is an example:
+// ```
+// dag Inst = (descend 0b1101, (operand "$src", 4), 0b1001,
+//                     (operand "$dst", 4));
+// ```
+// It represents the following instruction encoding:
+// ```
+// MSB                                                     LSB
+// 1101<encoding for operand src>1001<encoding for operand dst>
+// ```
+// For more details about DAG operators in the above snippet, please
+// refer to \file include/llvm/Target/Target.td.
+//
+// VarLenCodeEmitter will convert the above DAG into the same helper function
+// generated by CodeEmitter, `MCCodeEmitter::getBinaryCodeForInstr` (except
+// for a few details).
+//
+//===----------------------------------------------------------------------===//
+
+#include "VarLenCodeEmitterGen.h"
+#include "CodeGenInstruction.h"
+#include "CodeGenTarget.h"
+#include "SubtargetFeatureInfo.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace llvm;
+
+namespace {
+
+class VarLenCodeEmitterGen {
+  RecordKeeper &Records;
+
+  class VarLenInst {
+    size_t NumBits = 0; // Total width: the sum of all segment widths.
+
+    // Set once an `operand` or `slice` segment is seen; must start false.
+    bool HasDynamicSegment = false;
+
+    // {Number of bits, Value} pairs, ordered from LSB to MSB.
+    SmallVector<std::pair<unsigned, const Init *>, 4> Segments;
+
+    void buildRec(const DagInit *DI); // Appends DI's fragments to Segments.
+
+  public:
+    VarLenInst() = default;
+
+    explicit VarLenInst(const DagInit *DI);
+
+    /// Total number of bits in the encoding.
+    size_t size() const { return NumBits; }
+
+    using const_iterator = decltype(Segments)::const_iterator;
+
+    const_iterator begin() const { return Segments.begin(); }
+    const_iterator end() const { return Segments.end(); }
+    size_t getNumSegments() const { return Segments.size(); }
+
+    bool isFixedValueOnly() const { return !HasDynamicSegment; }
+  };
+
+  DenseMap<Record *, VarLenInst> VarLenInsts; // Keyed by encoding record.
+
+  // Emit base values (i.e. the fixed bits of each encoded instruction).
+  void emitInstructionBaseValues(
+      raw_ostream &OS,
+      ArrayRef<const CodeGenInstruction *> NumberedInstructions,
+      CodeGenTarget &Target, int HwMode = -1);
+
+  std::string getInstructionCase(Record *R, CodeGenTarget &Target);
+  std::string getInstructionCaseForEncoding(Record *R, Record *EncodingDef,
+                                            CodeGenTarget &Target);
+
+public:
+  explicit VarLenCodeEmitterGen(RecordKeeper &R) : Records(R) {}
+
+  void run(raw_ostream &OS);
+};
+
+} // end anonymous namespace
+
+VarLenCodeEmitterGen::VarLenInst::VarLenInst(const DagInit *DI)
+    : NumBits(0U), HasDynamicSegment(false) { // buildRec() only sets it true
+  buildRec(DI); // Flatten DI into Segments, then sum their bit widths.
+  for (const auto &S : Segments) NumBits += S.first;
+}
+
+void VarLenCodeEmitterGen::VarLenInst::buildRec(const DagInit *DI) {
+  // Flatten one encoding DAG node into Segments, ordered from LSB to MSB.
+  std::string Op = DI->getOperator()->getAsString();
+
+  if (Op == "ascend" || Op == "descend") {
+    // `descend` lists MSB first; walk it backwards so Segments stay LSB-first.
+    bool Reverse = Op == "descend";
+    int i = Reverse ? DI->getNumArgs() - 1 : 0;
+    int e = Reverse ? -1 : DI->getNumArgs();
+    int s = Reverse ? -1 : 1;
+    for (; i != e; i += s) {
+      const Init *Arg = DI->getArg(i);
+      if (const auto *BI = dyn_cast<BitsInit>(Arg)) {
+        if (!BI->isComplete())
+          PrintFatalError("Expecting complete bits init in `" + Op + "`");
+        Segments.push_back({BI->getNumBits(), BI});
+      } else if (const auto *BI = dyn_cast<BitInit>(Arg)) {
+        if (!BI->isConcrete())
+          PrintFatalError("Expecting concrete bit init in `" + Op + "`");
+        Segments.push_back({1, BI});
+      } else if (const auto *SubDI = dyn_cast<DagInit>(Arg)) {
+        buildRec(SubDI);
+      } else {
+        PrintFatalError("Unrecognized type of argument in `" + Op +
+                        "`: " + Arg->getAsString());
+      }
+    }
+  } else if (Op == "operand") {
+    // (operand <operand name>, <# of bits>)
+    if (DI->getNumArgs() != 2)
+      PrintFatalError("Expecting 2 arguments for `operand`");
+    HasDynamicSegment = true;
+    const Init *OperandName = DI->getArg(0), *NumBitsArg = DI->getArg(1);
+    if (!isa<StringInit>(OperandName) || !isa<IntInit>(NumBitsArg))
+      PrintFatalError("Invalid argument types for `operand`");
+
+    auto NumBitsVal = cast<IntInit>(NumBitsArg)->getValue();
+    if (NumBitsVal <= 0)
+      PrintFatalError("Invalid number of bits for `operand`");
+
+    Segments.push_back({NumBitsVal, OperandName});
+  } else if (Op == "slice") {
+    // (slice <operand name>, <high / low bit>, <low / high bit>)
+    if (DI->getNumArgs() != 3)
+      PrintFatalError("Expecting 3 arguments for `slice`");
+    HasDynamicSegment = true;
+    Init *OperandName = DI->getArg(0), *HiBit = DI->getArg(1),
+         *LoBit = DI->getArg(2);
+    if (!isa<StringInit>(OperandName) || !isa<IntInit>(HiBit) ||
+        !isa<IntInit>(LoBit))
+      PrintFatalError("Invalid argument types for `slice`");
+
+    auto HiBitVal = cast<IntInit>(HiBit)->getValue(),
+         LoBitVal = cast<IntInit>(LoBit)->getValue();
+    if (HiBitVal < 0 || LoBitVal < 0)
+      PrintFatalError("Invalid bit range for `slice`");
+    // The two bit indices may be given in either order.
+    bool NeedSwap = HiBitVal < LoBitVal;
+    unsigned SliceBits = static_cast<unsigned>(
+        NeedSwap ? LoBitVal - HiBitVal + 1 : HiBitVal - LoBitVal + 1);
+    if (NeedSwap) {
+      // Normalization: Hi bit should always be the second argument.
+      Init *const NewArgs[] = {OperandName, LoBit, HiBit};
+      Segments.push_back(
+          {SliceBits, DagInit::get(DI->getOperator(), nullptr, NewArgs, {})});
+    } else {
+      Segments.push_back({SliceBits, DI});
+    }
+  } else {
+    // Silently ignoring an unknown operator would drop its bits from Inst.
+    PrintFatalError("Unrecognized operator `" + Op + "` in encoding DAG");
+  }
+}
+
+void VarLenCodeEmitterGen::run(raw_ostream &OS) {
+  CodeGenTarget Target(Records);
+  auto Insts = Records.getAllDerivedDefinitions("Instruction");
+
+  auto NumberedInstructions = Target.getInstructionsByEnumValue();
+  const CodeGenHwModes &HWM = Target.getHwModes();
+
+  // The set of HwModes used by instruction encodings.
+  std::set<unsigned> HwModes;
+  for (const CodeGenInstruction *CGI : NumberedInstructions) {
+    Record *R = CGI->TheDef;
+
+    // Skip pseudos / TargetOpcode instrs; parse every other encoding DAG.
+    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
+        R->getValueAsBit("isPseudo"))
+      continue;
+
+    if (const RecordVal *RV = R->getValue("EncodingInfos")) {
+      if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
+        EncodingInfoByHwMode EBM(DI->getDef(), HWM);
+        for (auto &KV : EBM) {
+          HwModes.insert(KV.first);
+          Record *EncodingDef = KV.second;
+          auto *EncodingDag = EncodingDef->getValueAsDag("Inst");
+          VarLenInsts.insert({EncodingDef, VarLenInst(EncodingDag)});
+        }
+        continue;
+      }
+    }
+    auto *DI = R->getValueAsDag("Inst");
+    VarLenInsts.insert({R, VarLenInst(DI)});
+  }
+
+  // Emit the function header (always the APInt-based signature).
+  OS << "void " << Target.getName()
+     << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
+     << "    SmallVectorImpl<MCFixup> &Fixups,\n"
+     << "    APInt &Inst,\n"
+     << "    APInt &Scratch,\n"
+     << "    const MCSubtargetInfo &STI) const {\n";
+
+  // Emit instruction base values
+  if (HwModes.empty()) {
+    emitInstructionBaseValues(OS, NumberedInstructions, Target);
+  } else {
+    for (unsigned HwMode : HwModes)
+      emitInstructionBaseValues(OS, NumberedInstructions, Target, (int)HwMode);
+  }
+
+  if (!HwModes.empty()) {
+    OS << "  const unsigned (*Index)[2];\n"; // Set to Index_<Mode> below.
+    OS << "  const uint64_t *InstBits;\n";
+    OS << "  unsigned HwMode = STI.getHwMode();\n";
+    OS << "  switch (HwMode) {\n";
+    OS << "  default: llvm_unreachable(\"Unknown hardware mode!\"); break;\n";
+    for (unsigned I : HwModes) {
+      OS << "  case " << I << ": InstBits = InstBits_" << HWM.getMode(I).Name
+         << "; Index = Index_" << HWM.getMode(I).Name << "; break;\n";
+    }
+    OS << "  };\n";
+  }
+
+  // Emit helper function to retrieve base values.
+  OS << "  auto getInstBits = [&](unsigned Opcode) -> APInt {\n"
+     << "    unsigned NumBits = Index[Opcode][0];\n"
+     << "    if (!NumBits)\n"
+     << "      return APInt::getZeroWidth();\n"
+     << "    unsigned Idx = Index[Opcode][1];\n"
+     << "    ArrayRef<uint64_t> Data(&InstBits[Idx], "
+     << "APInt::getNumWords(NumBits));\n"
+     << "    return APInt(NumBits, Data);\n"
+     << "  };\n";
+
+  // Map to accumulate all the cases.
+  std::map<std::string, std::vector<std::string>> CaseMap;
+
+  // Construct all cases statement for each opcode
+  for (Record *R : Insts) {
+    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
+        R->getValueAsBit("isPseudo"))
+      continue;
+    std::string InstName =
+        (R->getValueAsString("Namespace") + "::" + R->getName()).str();
+    std::string Case = getInstructionCase(R, Target);
+
+    CaseMap[Case].push_back(std::move(InstName));
+  }
+
+  // Emit initial function code
+  OS << "  const unsigned opcode = MI.getOpcode();\n"
+     << "  switch (opcode) {\n";
+
+  // Emit each case statement
+  for (const auto &C : CaseMap) {
+    const std::string &Case = C.first;
+    const auto &InstList = C.second;
+
+    ListSeparator LS("\n");
+    for (const auto &InstName : InstList)
+      OS << LS << "    case " << InstName << ":";
+
+    OS << " {\n";
+    OS << Case;
+    OS << "      break;\n"
+       << "    }\n";
+  }
+  // Default case: unhandled opcode
+  OS << "  default:\n"
+     << "    std::string msg;\n"
+     << "    raw_string_ostream Msg(msg);\n"
+     << "    Msg << \"Not supported instr: \" << MI;\n"
+     << "    report_fatal_error(Msg.str().c_str());\n"
+     << "  }\n";
+  OS << "}\n\n";
+}
+
+static void emitInstBits(raw_ostream &IS, raw_ostream &SS, const APInt &Bits,
+                         unsigned &Index) {
+  const unsigned NumWords = Bits.getNumWords();
+  if (NumWords == 0) { // Zero-width value: empty entry, no storage words.
+    IS.indent(4) << "{/*NumBits*/0, /*Index*/0},";
+    return;
+  }
+  IS.indent(4) << "{/*NumBits*/" << Bits.getBitWidth() << ", "
+               << "/*Index*/" << Index << "},";
+  SS.indent(4); // Append this value's words; Index then points past them.
+  for (unsigned W = 0; W != NumWords; ++W)
+    SS << "UINT64_C(" << utostr(Bits.getRawData()[W]) << "),";
+  Index += NumWords;
+}
+
+void VarLenCodeEmitterGen::emitInstructionBaseValues(
+    raw_ostream &OS, ArrayRef<const CodeGenInstruction *> NumberedInstructions,
+    CodeGenTarget &Target, int HwMode) { // Emits Index + InstBits tables.
+  std::string IndexArray, StorageArray; // The Index and InstBits tables.
+  raw_string_ostream IS(IndexArray), SS(StorageArray);
+
+  const CodeGenHwModes &HWM = Target.getHwModes();
+  if (HwMode == -1) { // No HwModes: unsuffixed table names.
+    IS << "  static const unsigned Index[][2] = {\n";
+    SS << "  static const uint64_t InstBits[] = {\n";
+  } else {
+    StringRef Name = HWM.getMode(HwMode).Name;
+    IS << "  static const unsigned Index_" << Name << "[][2] = {\n";
+    SS << "  static const uint64_t InstBits_" << Name << "[] = {\n";
+  }
+
+  unsigned NumFixedValueWords = 0U; // Next free slot in InstBits.
+  for (const CodeGenInstruction *CGI : NumberedInstructions) {
+    Record *R = CGI->TheDef;
+
+    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
+        R->getValueAsBit("isPseudo")) {
+      IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\n"; // No encoding.
+      continue;
+    }
+
+    Record *EncodingDef = R; // Per-HwMode encoding if one exists.
+    if (const RecordVal *RV = R->getValue("EncodingInfos")) {
+      if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
+        EncodingInfoByHwMode EBM(DI->getDef(), HWM);
+        if (EBM.hasMode(HwMode))
+          EncodingDef = EBM.get(HwMode);
+      }
+    }
+
+    auto It = VarLenInsts.find(EncodingDef);
+    if (It == VarLenInsts.end())
+      PrintFatalError(EncodingDef, "VarLenInst not found for this record");
+    const VarLenInst &VLI = It->second;
+
+    unsigned i = 0U, BitWidth = VLI.size(); // i: bit offset of segment.
+
+    // Start from all-zero bits; only fixed-value segments are filled in.
+    APInt Value(BitWidth, 0);
+    auto SI = VLI.begin(), SE = VLI.end();
+    // Walk segments LSB-first; operand/slice segments just advance i.
+    while (i < BitWidth && SI != SE) {
+      unsigned SegmentNumBits = SI->first;
+      if (const auto *BI = dyn_cast<BitsInit>(SI->second)) {
+        for (unsigned Idx = 0U; Idx != SegmentNumBits; ++Idx) {
+          auto *B = cast<BitInit>(BI->getBit(Idx));
+          Value.setBitVal(i + Idx, B->getValue());
+        }
+      }
+      if (const auto *BI = dyn_cast<BitInit>(SI->second))
+        Value.setBitVal(i, BI->getValue());
+
+      i += SegmentNumBits;
+      ++SI;
+    }
+
+    emitInstBits(IS, SS, Value, NumFixedValueWords);
+    IS << '\t' << "// " << R->getName() << "\n"; // Label the row.
+    if (Value.getNumWords()) // Zero-width values emit no words.
+      SS << '\t' << "// " << R->getName() << "\n";
+  }
+  IS.indent(4) << "{/*NumBits*/0, /*Index*/0}\n  };\n"; // Final row.
+  SS.indent(4) << "UINT64_C(0)\n  };\n"; // Final element.
+
+  OS << IS.str() << SS.str(); // Index table first, then InstBits.
+}
+
+std::string VarLenCodeEmitterGen::getInstructionCase(Record *R,
+                                                     CodeGenTarget &Target) {
+  // Instructions without per-HwMode encodings use their own `Inst` directly.
+  const RecordVal *EncInfos = R->getValue("EncodingInfos");
+  auto *EncDef = EncInfos ? dyn_cast_or_null<DefInit>(EncInfos->getValue())
+                          : nullptr;
+  if (!EncDef)
+    return getInstructionCaseForEncoding(R, R, Target);
+
+  // Otherwise emit a `switch (HwMode)` with one case per encoding.
+  EncodingInfoByHwMode EBM(EncDef->getDef(), Target.getHwModes());
+  std::string Case = "      switch (HwMode) {\n"
+                     "      default: llvm_unreachable(\"Unhandled HwMode\");\n";
+  for (auto &KV : EBM) {
+    Case += "      case " + itostr(KV.first) + ": {\n";
+    Case += getInstructionCaseForEncoding(R, KV.second, Target);
+    Case += "      break;\n"
+            "      }\n";
+  }
+  return Case + "      }\n";
+}
+
+std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
+    Record *R, Record *EncodingDef, CodeGenTarget &Target) { // A case body.
+  auto It = VarLenInsts.find(EncodingDef);
+  if (It == VarLenInsts.end())
+    PrintFatalError(EncodingDef, "Parsed encoding record not found");
+  const VarLenInst &VLI = It->second;
+  size_t BitWidth = VLI.size();
+
+  CodeGenInstruction &CGI = Target.getInstruction(R);
+
+  std::string Case;
+  raw_string_ostream SS(Case);
+  // Resize Scratch to the full encoding width if any operand is encoded.
+  if (BitWidth && !VLI.isFixedValueOnly())
+    SS.indent(6) << "Scratch = Scratch.zextOrSelf(" << BitWidth << ");\n";
+  // Start from this opcode's fixed (base) bits.
+  SS.indent(6) << "Inst = getInstBits(opcode);\n";
+
+  // Splice each operand (or operand slice) encoding in at its bit offset.
+  size_t Offset = 0U;
+  for (const auto &Pair : VLI) {
+    unsigned NumBits = Pair.first;
+    const Init *Val = Pair.second;
+    // If it's a StringInit or DagInit, it's a reference to an operand
+    // or part of an operand.
+    if (isa<StringInit>(Val) || isa<DagInit>(Val)) {
+      StringRef OperandName;
+      unsigned LoBit = 0U;
+      if (const auto *SV = dyn_cast<StringInit>(Val)) {
+        OperandName = SV->getValue();
+      } else {
+        // Normalized: (slice <operand name>, <high bit>, <low bit>)
+        const auto *DV = cast<DagInit>(Val);
+        OperandName = cast<StringInit>(DV->getArg(0))->getValue();
+        LoBit = static_cast<unsigned>(cast<IntInit>(DV->getArg(2))->getValue());
+      }
+
+      auto OpIdx = CGI.Operands.ParseOperandName(OperandName);
+      unsigned FlatOpIdx = CGI.Operands.getFlattenedOperandNumber(OpIdx);
+      StringRef EncoderMethodName = "getMachineOpValue";
+      auto &CustomEncoder = CGI.Operands[OpIdx.first].EncoderMethodName;
+      if (!CustomEncoder.empty())
+        EncoderMethodName = CustomEncoder;
+
+      SS.indent(6) << "Scratch.clearAllBits();\n";
+      SS.indent(6) << "// op: " << OperandName.drop_front(1) << "\n";
+      SS.indent(6) << EncoderMethodName << "(MI, MI.getOperand("
+                   << utostr(FlatOpIdx) << "), Scratch, Fixups, STI);\n";
+      SS.indent(6) << "Inst.insertBits("
+                   << "Scratch.extractBits(" << utostr(NumBits) << ", "
+                   << utostr(LoBit) << ")"
+                   << ", " << Offset << ");\n";
+    }
+    Offset += NumBits;
+  }
+
+  StringRef PostEmitter = R->getValueAsString("PostEncoderMethod");
+  if (!PostEmitter.empty())
+    SS.indent(6) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
+
+  return SS.str(); // Flush SS into Case before handing the text back.
+}
+
+namespace llvm {
+
+void emitVarLenCodeEmitter(RecordKeeper &R, raw_ostream &OS) {
+  VarLenCodeEmitterGen Emitter(R); // Backend for DAG-typed `Inst` encodings.
+  Emitter.run(OS);
+}
+} // end namespace llvm

diff  --git a/llvm/utils/TableGen/VarLenCodeEmitterGen.h b/llvm/utils/TableGen/VarLenCodeEmitterGen.h
new file mode 100644
index 0000000000000..330b791b7cce1
--- /dev/null
+++ b/llvm/utils/TableGen/VarLenCodeEmitterGen.h
@@ -0,0 +1,25 @@
+//===- VarLenCodeEmitterGen.h - CEG for variable-length insts ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the CodeEmitterGen component for variable-length
+// instructions. See the .cpp file for more details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_UTILS_TABLEGEN_VARLENCODEEMITTERGEN_H
+#define LLVM_UTILS_TABLEGEN_VARLENCODEEMITTERGEN_H
+
+namespace llvm {
+
+class RecordKeeper;
+class raw_ostream;
+/// Emit <Target>MCCodeEmitter::getBinaryCodeForInstr() from DAG-typed `Inst`s.
+void emitVarLenCodeEmitter(RecordKeeper &R, raw_ostream &OS);
+
+} // end namespace llvm
+#endif


        


More information about the llvm-commits mailing list