[llvm] f8325f1 - [Tablegen] Bugfix and refactor VarLenCodeEmitter HwModes. (#68795)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 19 22:21:28 PDT 2023


Author: Erik Jonsson
Date: 2023-10-20T07:21:24+02:00
New Revision: f8325f12606d7c8510abbf933bf95983bf66da7d

URL: https://github.com/llvm/llvm-project/commit/f8325f12606d7c8510abbf933bf95983bf66da7d
DIFF: https://github.com/llvm/llvm-project/commit/f8325f12606d7c8510abbf933bf95983bf66da7d.diff

LOG: [Tablegen] Bugfix and refactor VarLenCodeEmitter HwModes. (#68795)

VarLenCodeEmitterGen produced code that did not compile if using
alternative encoding in different HwModes. It's not possible to assign

    unsigned **Index = Index_<mode>[][2] = { ... };

As a fix, Index and InstBits were removed in favor of mode-specific
getInstBits_<mode> functions since this is the only place the arrays are
accessed.

Handling of HwModes is now concentrated to the VarLenCodeEmitterGen::run
method reducing the overall amount of code and enabling other types of
alternative encodings not related to HwModes.

Added a test for VarLenCodeEmitterGen HwModes.

Make sure that HwModes are supported in the same way they are supported
for the standard CodeEmitter. It should be possible to define
instructions with universal encoding across modes, distinct encodings
for each mode or only define encodings for some modes.

Fixed indentation in generated code.

Added: 
    llvm/test/TableGen/VarLenEncoderHwModes.td

Modified: 
    llvm/utils/TableGen/VarLenCodeEmitterGen.cpp

Removed: 
    


################################################################################
diff --git a/llvm/test/TableGen/VarLenEncoderHwModes.td b/llvm/test/TableGen/VarLenEncoderHwModes.td
new file mode 100644
index 000000000000000..e0da0c9b93df618
--- /dev/null
+++ b/llvm/test/TableGen/VarLenEncoderHwModes.td
@@ -0,0 +1,110 @@
+// RUN: llvm-tblgen -gen-emitter -I %p/../../include %s | FileCheck %s
+
+// Verify VarLenCodeEmitterGen using EncodingInfos with different HwModes.
+
+include "llvm/Target/Target.td"
+
+def ArchInstrInfo : InstrInfo { }
+
+def Arch : Target {
+  let InstructionSet = ArchInstrInfo;
+}
+
+def Reg : Register<"reg">;
+
+def RegClass : RegisterClass<"foo", [i64], 0, (add Reg)>;
+
+def GR64 : RegisterOperand<RegClass>;
+
+def HasA : Predicate<"Subtarget->hasA()">;
+def HasB : Predicate<"Subtarget->hasB()">;
+
+def ModeA : HwMode<"+a", [HasA]>;
+def ModeB : HwMode<"+b", [HasB]>;
+
+def fooTypeEncA : InstructionEncoding {
+  dag Inst = (descend
+    (operand "$src", 4),
+    (operand "$dst", 4),
+    0b00000001
+  );
+}
+
+def fooTypeEncB : InstructionEncoding {
+  dag Inst = (descend
+    (operand "$dst", 4),
+    (operand "$src", 4),
+    0b00000010
+  );
+}
+
+def fooTypeEncC : InstructionEncoding {
+  dag Inst = (descend
+    (operand "$dst", 4),
+    (operand "$src", 4),
+    0b00000100
+  );
+}
+
+class VarLenInst : Instruction {
+  let AsmString = "foo $src, $dst";
+  let OutOperandList = (outs GR64:$dst);
+  let InOperandList  = (ins GR64:$src);
+}
+
+// Defined in both HwModes
+def foo : VarLenInst {
+  let EncodingInfos = EncodingByHwMode<
+    [ModeA, ModeB],
+    [fooTypeEncA, fooTypeEncB]
+  >;
+}
+
+// Same encoding in any HwMode
+def bar : VarLenInst {
+  dag Inst = (descend
+    (operand "$dst", 4),
+    (operand "$src", 4),
+    0b00000011
+  );
+}
+
+// Only defined in HwMode B.
+def baz : VarLenInst {
+  let EncodingInfos = EncodingByHwMode<
+    [ModeB],
+    [fooTypeEncC]
+  >;
+}
+
+// CHECK:     static const uint64_t InstBits_ModeA[] = {
+// CHECK:       UINT64_C(3),        // bar
+// CHECK:       UINT64_C(1),        // foo
+
+// CHECK:     static const uint64_t InstBits_ModeB[] = {
+// CHECK:       UINT64_C(3),        // bar
+// CHECK:       UINT64_C(4),        // baz
+// CHECK:       UINT64_C(2),        // foo
+
+// CHECK:     auto getInstBits_ModeA =
+// CHECK:       Idx = Index_ModeA
+
+// CHECK:     auto getInstBits_ModeB =
+// CHECK:       Idx = Index_ModeB
+
+// CHECK:     case ::bar: {
+// CHECK-NOT:   switch (Mode) {
+// CHECK:       Inst = getInstBits_ModeA
+
+// CHECK:     case ::foo: {
+// CHECK:       switch (Mode) {
+// CHECK:       case 1: {
+// CHECK:       Inst = getInstBits_ModeA
+// CHECK:       case 2: {
+// CHECK:       Inst = getInstBits_ModeB
+
+// CHECK:     case ::baz: {
+// CHECK:       case 1: {
+// CHECK:       llvm_unreachable("Undefined encoding in this mode");
+// CHECK:       case 2: {
+// CHECK:       Inst = getInstBits_ModeB

diff --git a/llvm/utils/TableGen/VarLenCodeEmitterGen.cpp b/llvm/utils/TableGen/VarLenCodeEmitterGen.cpp
index 7a24030e17d8a82..24f116bbeaced5f 100644
--- a/llvm/utils/TableGen/VarLenCodeEmitterGen.cpp
+++ b/llvm/utils/TableGen/VarLenCodeEmitterGen.cpp
@@ -67,17 +67,26 @@ namespace {
 class VarLenCodeEmitterGen {
   RecordKeeper &Records;
 
-  DenseMap<Record *, VarLenInst> VarLenInsts;
+  // Representation of alternative encodings used for HwModes.
+  using AltEncodingTy = int;
+  // Mode identifier when only one encoding is defined.
+  const AltEncodingTy Universal = -1;
+  // The set of alternative instruction encodings with a descriptive
+  // name suffix to improve readability of the generated code.
+  std::map<AltEncodingTy, std::string> Modes;
+
+  DenseMap<Record *, DenseMap<AltEncodingTy, VarLenInst>> VarLenInsts;
 
   // Emit based values (i.e. fixed bits in the encoded instructions)
   void emitInstructionBaseValues(
       raw_ostream &OS,
       ArrayRef<const CodeGenInstruction *> NumberedInstructions,
-      CodeGenTarget &Target, int HwMode = -1);
+      CodeGenTarget &Target, AltEncodingTy Mode);
 
-  std::string getInstructionCase(Record *R, CodeGenTarget &Target);
-  std::string getInstructionCaseForEncoding(Record *R, Record *EncodingDef,
-                                            CodeGenTarget &Target);
+  std::string getInstructionCases(Record *R, CodeGenTarget &Target);
+  std::string getInstructionCaseForEncoding(Record *R, AltEncodingTy Mode,
+                                            const VarLenInst &VLI,
+                                            CodeGenTarget &Target, int I);
 
 public:
   explicit VarLenCodeEmitterGen(RecordKeeper &R) : Records(R) {}
@@ -214,36 +223,38 @@ void VarLenCodeEmitterGen::run(raw_ostream &OS) {
   auto Insts = Records.getAllDerivedDefinitions("Instruction");
 
   auto NumberedInstructions = Target.getInstructionsByEnumValue();
-  const CodeGenHwModes &HWM = Target.getHwModes();
 
-  // The set of HwModes used by instruction encodings.
-  std::set<unsigned> HwModes;
   for (const CodeGenInstruction *CGI : NumberedInstructions) {
     Record *R = CGI->TheDef;
-
     // Create the corresponding VarLenInst instance.
     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
         R->getValueAsBit("isPseudo"))
       continue;
 
+    // Setup alternative encodings according to HwModes
     if (const RecordVal *RV = R->getValue("EncodingInfos")) {
       if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
+        const CodeGenHwModes &HWM = Target.getHwModes();
         EncodingInfoByHwMode EBM(DI->getDef(), HWM);
         for (auto &KV : EBM) {
-          HwModes.insert(KV.first);
+          AltEncodingTy Mode = KV.first;
+          Modes.insert({Mode, "_" + HWM.getMode(Mode).Name.str()});
           Record *EncodingDef = KV.second;
           RecordVal *RV = EncodingDef->getValue("Inst");
           DagInit *DI = cast<DagInit>(RV->getValue());
-          VarLenInsts.insert({EncodingDef, VarLenInst(DI, RV)});
+          VarLenInsts[R].insert({Mode, VarLenInst(DI, RV)});
         }
         continue;
       }
     }
     RecordVal *RV = R->getValue("Inst");
     DagInit *DI = cast<DagInit>(RV->getValue());
-    VarLenInsts.insert({R, VarLenInst(DI, RV)});
+    VarLenInsts[R].insert({Universal, VarLenInst(DI, RV)});
   }
 
+  if (Modes.empty())
+    Modes.insert({Universal, ""}); // Base case, skip suffix.
+
   // Emit function declaration
   OS << "void " << Target.getName()
      << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
@@ -253,36 +264,26 @@ void VarLenCodeEmitterGen::run(raw_ostream &OS) {
      << "    const MCSubtargetInfo &STI) const {\n";
 
   // Emit instruction base values
-  if (HwModes.empty()) {
-    emitInstructionBaseValues(OS, NumberedInstructions, Target);
-  } else {
-    for (unsigned HwMode : HwModes)
-      emitInstructionBaseValues(OS, NumberedInstructions, Target, (int)HwMode);
-  }
+  for (const auto &Mode : Modes)
+    emitInstructionBaseValues(OS, NumberedInstructions, Target, Mode.first);
 
-  if (!HwModes.empty()) {
-    OS << "  const unsigned **Index;\n";
-    OS << "  const uint64_t *InstBits;\n";
-    OS << "  unsigned HwMode = STI.getHwMode();\n";
-    OS << "  switch (HwMode) {\n";
-    OS << "  default: llvm_unreachable(\"Unknown hardware mode!\"); break;\n";
-    for (unsigned I : HwModes) {
-      OS << "  case " << I << ": InstBits = InstBits_" << HWM.getMode(I).Name
-         << "; Index = Index_" << HWM.getMode(I).Name << "; break;\n";
-    }
-    OS << "  };\n";
+  if (Modes.size() > 1) {
+    OS << "  unsigned Mode = STI.getHwMode();\n";
   }
 
-  // Emit helper function to retrieve base values.
-  OS << "  auto getInstBits = [&](unsigned Opcode) -> APInt {\n"
-     << "    unsigned NumBits = Index[Opcode][0];\n"
-     << "    if (!NumBits)\n"
-     << "      return APInt::getZeroWidth();\n"
-     << "    unsigned Idx = Index[Opcode][1];\n"
-     << "    ArrayRef<uint64_t> Data(&InstBits[Idx], "
-     << "APInt::getNumWords(NumBits));\n"
-     << "    return APInt(NumBits, Data);\n"
-     << "  };\n";
+  for (const auto &Mode : Modes) {
+    // Emit helper function to retrieve base values.
+    OS << "  auto getInstBits" << Mode.second
+       << " = [&](unsigned Opcode) -> APInt {\n"
+       << "    unsigned NumBits = Index" << Mode.second << "[Opcode][0];\n"
+       << "    if (!NumBits)\n"
+       << "      return APInt::getZeroWidth();\n"
+       << "    unsigned Idx = Index" << Mode.second << "[Opcode][1];\n"
+       << "    ArrayRef<uint64_t> Data(&InstBits" << Mode.second << "[Idx], "
+       << "APInt::getNumWords(NumBits));\n"
+       << "    return APInt(NumBits, Data);\n"
+       << "  };\n";
+  }
 
   // Map to accumulate all the cases.
   std::map<std::string, std::vector<std::string>> CaseMap;
@@ -294,7 +295,7 @@ void VarLenCodeEmitterGen::run(raw_ostream &OS) {
       continue;
     std::string InstName =
         (R->getValueAsString("Namespace") + "::" + R->getName()).str();
-    std::string Case = getInstructionCase(R, Target);
+    std::string Case = getInstructionCases(R, Target);
 
     CaseMap[Case].push_back(std::move(InstName));
   }
@@ -344,19 +345,12 @@ static void emitInstBits(raw_ostream &IS, raw_ostream &SS, const APInt &Bits,
 
 void VarLenCodeEmitterGen::emitInstructionBaseValues(
     raw_ostream &OS, ArrayRef<const CodeGenInstruction *> NumberedInstructions,
-    CodeGenTarget &Target, int HwMode) {
+    CodeGenTarget &Target, AltEncodingTy Mode) {
   std::string IndexArray, StorageArray;
   raw_string_ostream IS(IndexArray), SS(StorageArray);
 
-  const CodeGenHwModes &HWM = Target.getHwModes();
-  if (HwMode == -1) {
-    IS << "  static const unsigned Index[][2] = {\n";
-    SS << "  static const uint64_t InstBits[] = {\n";
-  } else {
-    StringRef Name = HWM.getMode(HwMode).Name;
-    IS << "  static const unsigned Index_" << Name << "[][2] = {\n";
-    SS << "  static const uint64_t InstBits_" << Name << "[] = {\n";
-  }
+  IS << "  static const unsigned Index" << Modes[Mode] << "[][2] = {\n";
+  SS << "  static const uint64_t InstBits" << Modes[Mode] << "[] = {\n";
 
   unsigned NumFixedValueWords = 0U;
   for (const CodeGenInstruction *CGI : NumberedInstructions) {
@@ -368,20 +362,18 @@ void VarLenCodeEmitterGen::emitInstructionBaseValues(
       continue;
     }
 
-    Record *EncodingDef = R;
-    if (const RecordVal *RV = R->getValue("EncodingInfos")) {
-      if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
-        EncodingInfoByHwMode EBM(DI->getDef(), HWM);
-        if (EBM.hasMode(HwMode))
-          EncodingDef = EBM.get(HwMode);
-      }
+    const auto InstIt = VarLenInsts.find(R);
+    if (InstIt == VarLenInsts.end())
+      PrintFatalError(R, "VarLenInst not found for this record");
+    auto ModeIt = InstIt->second.find(Mode);
+    if (ModeIt == InstIt->second.end())
+      ModeIt = InstIt->second.find(Universal);
+    if (ModeIt == InstIt->second.end()) {
+      IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\t"
+                   << "// " << R->getName() << " no encoding\n";
+      continue;
     }
-
-    auto It = VarLenInsts.find(EncodingDef);
-    if (It == VarLenInsts.end())
-      PrintFatalError(EncodingDef, "VarLenInst not found for this record");
-    const VarLenInst &VLI = It->second;
-
+    const VarLenInst &VLI = ModeIt->second;
     unsigned i = 0U, BitWidth = VLI.size();
 
     // Start by filling in fixed values.
@@ -414,34 +406,45 @@ void VarLenCodeEmitterGen::emitInstructionBaseValues(
   OS << IS.str() << SS.str();
 }
 
-std::string VarLenCodeEmitterGen::getInstructionCase(Record *R,
-                                                     CodeGenTarget &Target) {
+std::string VarLenCodeEmitterGen::getInstructionCases(Record *R,
+                                                      CodeGenTarget &Target) {
+  auto It = VarLenInsts.find(R);
+  if (It == VarLenInsts.end())
+    PrintFatalError(R, "Parsed encoding record not found");
+  const auto &Map = It->second;
+
+  // Is this instruction's encoding universal (same for all modes)?
+  // Always true if there is only one mode.
+  if (Map.size() == 1 && Map.begin()->first == Universal) {
+    // Universal, just pick the first mode.
+    AltEncodingTy Mode = Modes.begin()->first;
+    const auto &Encoding = Map.begin()->second;
+    return getInstructionCaseForEncoding(R, Mode, Encoding, Target, 6);
+  }
+
   std::string Case;
-  if (const RecordVal *RV = R->getValue("EncodingInfos")) {
-    if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
-      const CodeGenHwModes &HWM = Target.getHwModes();
-      EncodingInfoByHwMode EBM(DI->getDef(), HWM);
-      Case += "      switch (HwMode) {\n";
-      Case += "      default: llvm_unreachable(\"Unhandled HwMode\");\n";
-      for (auto &KV : EBM) {
-        Case += "      case " + itostr(KV.first) + ": {\n";
-        Case += getInstructionCaseForEncoding(R, KV.second, Target);
-        Case += "      break;\n";
-        Case += "      }\n";
-      }
-      Case += "      }\n";
-      return Case;
+  Case += "      switch (Mode) {\n";
+  Case += "      default: llvm_unreachable(\"Unhandled Mode\");\n";
+  for (const auto &Mode : Modes) {
+    Case += "      case " + itostr(Mode.first) + ": {\n";
+    const auto &It = Map.find(Mode.first);
+    if (It == Map.end()) {
+      Case +=
+          "        llvm_unreachable(\"Undefined encoding in this mode\");\n";
+    } else {
+      Case +=
+          getInstructionCaseForEncoding(R, It->first, It->second, Target, 8);
     }
+    Case += "        break;\n";
+    Case += "      }\n";
   }
-  return getInstructionCaseForEncoding(R, R, Target);
+  Case += "      }\n";
+  return Case;
 }
 
 std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
-    Record *R, Record *EncodingDef, CodeGenTarget &Target) {
-  auto It = VarLenInsts.find(EncodingDef);
-  if (It == VarLenInsts.end())
-    PrintFatalError(EncodingDef, "Parsed encoding record not found");
-  const VarLenInst &VLI = It->second;
+    Record *R, AltEncodingTy Mode, const VarLenInst &VLI, CodeGenTarget &Target,
+    int I) {
   size_t BitWidth = VLI.size();
 
   CodeGenInstruction &CGI = Target.getInstruction(R);
@@ -450,9 +453,9 @@ std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
   raw_string_ostream SS(Case);
   // Resize the scratch buffer.
   if (BitWidth && !VLI.isFixedValueOnly())
-    SS.indent(6) << "Scratch = Scratch.zext(" << BitWidth << ");\n";
+    SS.indent(I) << "Scratch = Scratch.zext(" << BitWidth << ");\n";
   // Populate based value.
-  SS.indent(6) << "Inst = getInstBits(opcode);\n";
+  SS.indent(I) << "Inst = getInstBits" << Modes[Mode] << "(opcode);\n";
 
   // Process each segment in VLI.
   size_t Offset = 0U;
@@ -480,17 +483,17 @@ std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
       if (ES.CustomEncoder.size())
         CustomEncoder = ES.CustomEncoder;
 
-      SS.indent(6) << "Scratch.clearAllBits();\n";
-      SS.indent(6) << "// op: " << OperandName.drop_front(1) << "\n";
+      SS.indent(I) << "Scratch.clearAllBits();\n";
+      SS.indent(I) << "// op: " << OperandName.drop_front(1) << "\n";
       if (CustomEncoder.empty())
-        SS.indent(6) << "getMachineOpValue(MI, MI.getOperand("
+        SS.indent(I) << "getMachineOpValue(MI, MI.getOperand("
                      << utostr(FlatOpIdx) << ")";
       else
-        SS.indent(6) << CustomEncoder << "(MI, /*OpIdx=*/" << utostr(FlatOpIdx);
+        SS.indent(I) << CustomEncoder << "(MI, /*OpIdx=*/" << utostr(FlatOpIdx);
 
       SS << ", /*Pos=*/" << utostr(Offset) << ", Scratch, Fixups, STI);\n";
 
-      SS.indent(6) << "Inst.insertBits("
+      SS.indent(I) << "Inst.insertBits("
                    << "Scratch.extractBits(" << utostr(NumBits) << ", "
                    << utostr(LoBit) << ")"
                    << ", " << Offset << ");\n";
@@ -500,7 +503,7 @@ std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
 
   StringRef PostEmitter = R->getValueAsString("PostEncoderMethod");
   if (!PostEmitter.empty())
-    SS.indent(6) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
+    SS.indent(I) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
 
   return Case;
 }


        


More information about the llvm-commits mailing list