[llvm] f3d2a31 - [X86][CodeGen] Cleanup code for EVEX2VEX pass, NFCI

Shengchen Kan via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 27 21:11:53 PST 2023


Author: Shengchen Kan
Date: 2023-11-28T13:11:15+08:00
New Revision: f3d2a31d7d433f9c843a61caa7b025f3b7188ddf

URL: https://github.com/llvm/llvm-project/commit/f3d2a31d7d433f9c843a61caa7b025f3b7188ddf
DIFF: https://github.com/llvm/llvm-project/commit/f3d2a31d7d433f9c843a61caa7b025f3b7188ddf.diff

LOG: [X86][CodeGen] Cleanup code for EVEX2VEX pass, NFCI

1. Remove unused variables, e.g. the X86Subtarget object in performCustomAdjustments
2. Define checkVEXInstPredicate directly instead of generating it, because
   the function is small and it is unlikely that more instructions will
   need a predicate check in the future
3. Check that the tables are sorted only once per function (instead of once
   per instruction)
4. Remove some blank lines and clang-format the code

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86EvexToVex.cpp
    llvm/lib/Target/X86/X86InstrFormats.td
    llvm/lib/Target/X86/X86InstrSSE.td
    llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86EvexToVex.cpp b/llvm/lib/Target/X86/X86EvexToVex.cpp
index fda6c15fed34db7..c425c37b4186812 100644
--- a/llvm/lib/Target/X86/X86EvexToVex.cpp
+++ b/llvm/lib/Target/X86/X86EvexToVex.cpp
@@ -12,9 +12,10 @@
 /// are encoded using the EVEX prefix and if possible replaces them by their
 /// corresponding VEX encoding which is usually shorter by 2 bytes.
 /// EVEX instructions may be encoded via the VEX prefix when the AVX-512
-/// instruction has a corresponding AVX/AVX2 opcode, when vector length 
-/// accessed by instruction is less than 512 bits and when it does not use 
-//  the xmm or the mask registers or xmm/ymm registers with indexes higher than 15.
+/// instruction has a corresponding AVX/AVX2 opcode, when vector length
+/// accessed by instruction is less than 512 bits and when it does not use
+//  the xmm or the mask registers or xmm/ymm registers with indexes higher
+//  than 15.
 /// The pass applies code reduction on the generated code for AVX-512 instrs.
 //
 //===----------------------------------------------------------------------===//
@@ -39,16 +40,16 @@ using namespace llvm;
 
 // Including the generated EVEX2VEX tables.
 struct X86EvexToVexCompressTableEntry {
-  uint16_t EvexOpcode;
-  uint16_t VexOpcode;
+  uint16_t EvexOpc;
+  uint16_t VexOpc;
 
   bool operator<(const X86EvexToVexCompressTableEntry &RHS) const {
-    return EvexOpcode < RHS.EvexOpcode;
+    return EvexOpc < RHS.EvexOpc;
   }
 
   friend bool operator<(const X86EvexToVexCompressTableEntry &TE,
                         unsigned Opc) {
-    return TE.EvexOpcode < Opc;
+    return TE.EvexOpc < Opc;
   }
 };
 #include "X86GenEVEX2VEXTables.inc"
@@ -61,16 +62,9 @@ struct X86EvexToVexCompressTableEntry {
 namespace {
 
 class EvexToVexInstPass : public MachineFunctionPass {
-
-  /// For EVEX instructions that can be encoded using VEX encoding, replace
-  /// them by the VEX encoding in order to reduce size.
-  bool CompressEvexToVexImpl(MachineInstr &MI) const;
-
 public:
   static char ID;
-
-  EvexToVexInstPass() : MachineFunctionPass(ID) { }
-
+  EvexToVexInstPass() : MachineFunctionPass(ID) {}
   StringRef getPassName() const override { return EVEX2VEX_DESC; }
 
   /// Loop over all of the basic blocks, replacing EVEX instructions
@@ -82,53 +76,23 @@ class EvexToVexInstPass : public MachineFunctionPass {
     return MachineFunctionProperties().set(
         MachineFunctionProperties::Property::NoVRegs);
   }
-
-private:
-  /// Machine instruction info used throughout the class.
-  const X86InstrInfo *TII = nullptr;
-
-  const X86Subtarget *ST = nullptr;
 };
 
 } // end anonymous namespace
 
 char EvexToVexInstPass::ID = 0;
 
-bool EvexToVexInstPass::runOnMachineFunction(MachineFunction &MF) {
-  TII = MF.getSubtarget<X86Subtarget>().getInstrInfo();
-
-  ST = &MF.getSubtarget<X86Subtarget>();
-  if (!ST->hasAVX512())
-    return false;
-
-  bool Changed = false;
-
-  /// Go over all basic blocks in function and replace
-  /// EVEX encoded instrs by VEX encoding when possible.
-  for (MachineBasicBlock &MBB : MF) {
-
-    // Traverse the basic block.
-    for (MachineInstr &MI : MBB)
-      Changed |= CompressEvexToVexImpl(MI);
-  }
-
-  return Changed;
-}
-
 static bool usesExtendedRegister(const MachineInstr &MI) {
   auto isHiRegIdx = [](unsigned Reg) {
     // Check for XMM register with indexes between 16 - 31.
     if (Reg >= X86::XMM16 && Reg <= X86::XMM31)
       return true;
-
     // Check for YMM register with indexes between 16 - 31.
     if (Reg >= X86::YMM16 && Reg <= X86::YMM31)
       return true;
-
     // Check for GPR with indexes between 16 - 31.
     if (X86II::isApxExtendedReg(Reg))
       return true;
-
     return false;
   };
 
@@ -139,10 +103,8 @@ static bool usesExtendedRegister(const MachineInstr &MI) {
       continue;
 
     Register Reg = MO.getReg();
-
-    assert(!(Reg >= X86::ZMM0 && Reg <= X86::ZMM31) &&
+    assert(!X86II::isZMMReg(Reg) &&
            "ZMM instructions should not be in the EVEX->VEX tables");
-
     if (isHiRegIdx(Reg))
       return true;
   }
@@ -150,21 +112,58 @@ static bool usesExtendedRegister(const MachineInstr &MI) {
   return false;
 }
 
+static bool checkVEXInstPredicate(unsigned EvexOpc, const X86Subtarget &ST) {
+  switch (EvexOpc) {
+  default:
+    return true;
+  case X86::VCVTNEPS2BF16Z128rm:
+  case X86::VCVTNEPS2BF16Z128rr:
+  case X86::VCVTNEPS2BF16Z256rm:
+  case X86::VCVTNEPS2BF16Z256rr:
+    return ST.hasAVXNECONVERT();
+  case X86::VPDPBUSDSZ128m:
+  case X86::VPDPBUSDSZ128r:
+  case X86::VPDPBUSDSZ256m:
+  case X86::VPDPBUSDSZ256r:
+  case X86::VPDPBUSDZ128m:
+  case X86::VPDPBUSDZ128r:
+  case X86::VPDPBUSDZ256m:
+  case X86::VPDPBUSDZ256r:
+  case X86::VPDPWSSDSZ128m:
+  case X86::VPDPWSSDSZ128r:
+  case X86::VPDPWSSDSZ256m:
+  case X86::VPDPWSSDSZ256r:
+  case X86::VPDPWSSDZ128m:
+  case X86::VPDPWSSDZ128r:
+  case X86::VPDPWSSDZ256m:
+  case X86::VPDPWSSDZ256r:
+    return ST.hasAVXVNNI();
+  case X86::VPMADD52HUQZ128m:
+  case X86::VPMADD52HUQZ128r:
+  case X86::VPMADD52HUQZ256m:
+  case X86::VPMADD52HUQZ256r:
+  case X86::VPMADD52LUQZ128m:
+  case X86::VPMADD52LUQZ128r:
+  case X86::VPMADD52LUQZ256m:
+  case X86::VPMADD52LUQZ256r:
+    return ST.hasAVXIFMA();
+  }
+}
+
 // Do any custom cleanup needed to finalize the conversion.
-static bool performCustomAdjustments(MachineInstr &MI, unsigned NewOpc,
-                                     const X86Subtarget *ST) {
-  (void)NewOpc;
+static bool performCustomAdjustments(MachineInstr &MI, unsigned VexOpc) {
+  (void)VexOpc;
   unsigned Opc = MI.getOpcode();
   switch (Opc) {
   case X86::VALIGNDZ128rri:
   case X86::VALIGNDZ128rmi:
   case X86::VALIGNQZ128rri:
   case X86::VALIGNQZ128rmi: {
-    assert((NewOpc == X86::VPALIGNRrri || NewOpc == X86::VPALIGNRrmi) &&
+    assert((VexOpc == X86::VPALIGNRrri || VexOpc == X86::VPALIGNRrmi) &&
            "Unexpected new opcode!");
-    unsigned Scale = (Opc == X86::VALIGNQZ128rri ||
-                      Opc == X86::VALIGNQZ128rmi) ? 8 : 4;
-    MachineOperand &Imm = MI.getOperand(MI.getNumExplicitOperands()-1);
+    unsigned Scale =
+        (Opc == X86::VALIGNQZ128rri || Opc == X86::VALIGNQZ128rmi) ? 8 : 4;
+    MachineOperand &Imm = MI.getOperand(MI.getNumExplicitOperands() - 1);
     Imm.setImm(Imm.getImm() * Scale);
     break;
   }
@@ -176,10 +175,10 @@ static bool performCustomAdjustments(MachineInstr &MI, unsigned NewOpc,
   case X86::VSHUFI32X4Z256rri:
   case X86::VSHUFI64X2Z256rmi:
   case X86::VSHUFI64X2Z256rri: {
-    assert((NewOpc == X86::VPERM2F128rr || NewOpc == X86::VPERM2I128rr ||
-            NewOpc == X86::VPERM2F128rm || NewOpc == X86::VPERM2I128rm) &&
+    assert((VexOpc == X86::VPERM2F128rr || VexOpc == X86::VPERM2I128rr ||
+            VexOpc == X86::VPERM2F128rm || VexOpc == X86::VPERM2I128rm) &&
            "Unexpected new opcode!");
-    MachineOperand &Imm = MI.getOperand(MI.getNumExplicitOperands()-1);
+    MachineOperand &Imm = MI.getOperand(MI.getNumExplicitOperands() - 1);
     int64_t ImmVal = Imm.getImm();
     // Set bit 5, move bit 1 to bit 4, copy bit 0.
     Imm.setImm(0x20 | ((ImmVal & 2) << 3) | (ImmVal & 1));
@@ -212,10 +211,9 @@ static bool performCustomAdjustments(MachineInstr &MI, unsigned NewOpc,
   return true;
 }
 
-
 // For EVEX instructions that can be encoded using VEX encoding
 // replace them by the VEX encoding in order to reduce size.
-bool EvexToVexInstPass::CompressEvexToVexImpl(MachineInstr &MI) const {
+static bool CompressEvexToVexImpl(MachineInstr &MI, const X86Subtarget &ST) {
   // VEX format.
   // # of bytes: 0,2,3  1      1      0,1   0,1,2,4  0,1
   //  [Prefixes] [VEX]  OPCODE ModR/M [SIB] [DISP]  [IMM]
@@ -223,7 +221,6 @@ bool EvexToVexInstPass::CompressEvexToVexImpl(MachineInstr &MI) const {
   // EVEX format.
   //  # of bytes: 4    1      1      1      4       / 1         1
   //  [Prefixes]  EVEX Opcode ModR/M [SIB] [Disp32] / [Disp8*N] [Immediate]
-
   const MCInstrDesc &Desc = MI.getDesc();
 
   // Check for EVEX instructions only.
@@ -241,6 +238,29 @@ bool EvexToVexInstPass::CompressEvexToVexImpl(MachineInstr &MI) const {
   if (Desc.TSFlags & X86II::EVEX_L2)
     return false;
 
+  // Use the VEX.L bit to select the 128 or 256-bit table.
+  ArrayRef<X86EvexToVexCompressTableEntry> Table =
+      (Desc.TSFlags & X86II::VEX_L) ? ArrayRef(X86EvexToVex256CompressTable)
+                                    : ArrayRef(X86EvexToVex128CompressTable);
+
+  unsigned EvexOpc = MI.getOpcode();
+  const auto *I = llvm::lower_bound(Table, EvexOpc);
+  if (I == Table.end() || I->EvexOpc != EvexOpc)
+    return false;
+
+  if (usesExtendedRegister(MI))
+    return false;
+  if (!checkVEXInstPredicate(EvexOpc, ST))
+    return false;
+  if (!performCustomAdjustments(MI, I->VexOpc))
+    return false;
+
+  MI.setDesc(ST.getInstrInfo()->get(I->VexOpc));
+  MI.setAsmPrinterFlag(X86::AC_EVEX_2_VEX);
+  return true;
+}
+
+bool EvexToVexInstPass::runOnMachineFunction(MachineFunction &MF) {
 #ifndef NDEBUG
   // Make sure the tables are sorted.
   static std::atomic<bool> TableChecked(false);
@@ -252,30 +272,21 @@ bool EvexToVexInstPass::CompressEvexToVexImpl(MachineInstr &MI) const {
     TableChecked.store(true, std::memory_order_relaxed);
   }
 #endif
-
-  // Use the VEX.L bit to select the 128 or 256-bit table.
-  ArrayRef<X86EvexToVexCompressTableEntry> Table =
-      (Desc.TSFlags & X86II::VEX_L) ? ArrayRef(X86EvexToVex256CompressTable)
-                                    : ArrayRef(X86EvexToVex128CompressTable);
-
-  const auto *I = llvm::lower_bound(Table, MI.getOpcode());
-  if (I == Table.end() || I->EvexOpcode != MI.getOpcode())
+  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
+  if (!ST.hasAVX512())
     return false;
 
-  unsigned NewOpc = I->VexOpcode;
-
-  if (usesExtendedRegister(MI))
-    return false;
-
-  if (!CheckVEXInstPredicate(MI, ST))
-    return false;
+  bool Changed = false;
 
-  if (!performCustomAdjustments(MI, NewOpc, ST))
-    return false;
+  /// Go over all basic blocks in function and replace
+  /// EVEX encoded instrs by VEX encoding when possible.
+  for (MachineBasicBlock &MBB : MF) {
+    // Traverse the basic block.
+    for (MachineInstr &MI : MBB)
+      Changed |= CompressEvexToVexImpl(MI, ST);
+  }
 
-  MI.setDesc(TII->get(NewOpc));
-  MI.setAsmPrinterFlag(X86::AC_EVEX_2_VEX);
-  return true;
+  return Changed;
 }
 
 INITIALIZE_PASS(EvexToVexInstPass, EVEX2VEX_NAME, EVEX2VEX_DESC, false, false)

diff  --git a/llvm/lib/Target/X86/X86InstrFormats.td b/llvm/lib/Target/X86/X86InstrFormats.td
index 41d555d506598cd..68a9bb7053a1c97 100644
--- a/llvm/lib/Target/X86/X86InstrFormats.td
+++ b/llvm/lib/Target/X86/X86InstrFormats.td
@@ -371,8 +371,6 @@ class X86Inst<bits<8> opcod, Format f, ImmType i, dag outs, dag ins,
   bit notEVEX2VEXConvertible = 0; // Prevent EVEX->VEX conversion.
   ExplicitOpPrefix explicitOpPrefix = NoExplicitOpPrefix;
   bits<2> explicitOpPrefixBits = explicitOpPrefix.Value;
-  // Force to check predicate before compress EVEX to VEX encoding.
-  bit checkVEXPredicate = 0;
   // TSFlags layout should be kept in sync with X86BaseInfo.h.
   let TSFlags{6-0}   = FormBits;
   let TSFlags{8-7}   = OpSizeBits;

diff  --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td
index add24a061765014..ef6db2d45d66124 100644
--- a/llvm/lib/Target/X86/X86InstrSSE.td
+++ b/llvm/lib/Target/X86/X86InstrSSE.td
@@ -7316,7 +7316,7 @@ defm VMASKMOVPD : avx_movmask_rm<0x2D, 0x2F, "vmaskmovpd",
 // AVX_VNNI
 //===----------------------------------------------------------------------===//
 let Predicates = [HasAVXVNNI, NoVLX_Or_NoVNNI], Constraints = "$src1 = $dst",
-    explicitOpPrefix = ExplicitVEX, checkVEXPredicate = 1 in
+    explicitOpPrefix = ExplicitVEX in
 multiclass avx_vnni_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
                        bit IsCommutable> {
   let isCommutable = IsCommutable in
@@ -8142,8 +8142,7 @@ let isCommutable = 0 in {
 }
 
 // AVX-IFMA
-let Predicates = [HasAVXIFMA, NoVLX_Or_NoIFMA], Constraints = "$src1 = $dst",
-    checkVEXPredicate = 1 in
+let Predicates = [HasAVXIFMA, NoVLX_Or_NoIFMA], Constraints = "$src1 = $dst" in
 multiclass avx_ifma_rm<bits<8> opc, string OpcodeStr, SDNode OpNode> {
   // NOTE: The SDNode have the multiply operands first with the add last.
   // This enables commuted load patterns to be autogenerated by tablegen.
@@ -8287,7 +8286,6 @@ let Predicates = [HasAVXNECONVERT] in {
        f256mem>, T8XD;
   defm VCVTNEOPH2PS : AVX_NE_CONVERT_BASE<0xb0, "vcvtneoph2ps", f128mem,
        f256mem>, T8PS;
-  let checkVEXPredicate = 1 in
   defm VCVTNEPS2BF16 : VCVTNEPS2BF16_BASE, VEX, T8XS, ExplicitVEXPrefix;
 
   def : Pat<(v8bf16 (X86vfpround (v8f32 VR256:$src))),

diff  --git a/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp b/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp
index 9871cf62cc0ad68..c80d9a199fa3c19 100644
--- a/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp
+++ b/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp
@@ -33,14 +33,12 @@ class X86EVEX2VEXTablesEmitter {
   // to make the search more efficient
   std::map<uint64_t, std::vector<const CodeGenInstruction *>> VEXInsts;
 
-  typedef std::pair<const CodeGenInstruction *, const CodeGenInstruction *> Entry;
-  typedef std::pair<StringRef, StringRef> Predicate;
+  typedef std::pair<const CodeGenInstruction *, const CodeGenInstruction *>
+      Entry;
 
   // Represent both compress tables
   std::vector<Entry> EVEX2VEX128;
   std::vector<Entry> EVEX2VEX256;
-  // Represent predicates of VEX instructions.
-  std::vector<Predicate> EVEX2VEXPredicates;
 
 public:
   X86EVEX2VEXTablesEmitter(RecordKeeper &R) : Records(R), Target(R) {}
@@ -52,9 +50,6 @@ class X86EVEX2VEXTablesEmitter {
   // Prints the given table as a C++ array of type
   // X86EvexToVexCompressTableEntry
   void printTable(const std::vector<Entry> &Table, raw_ostream &OS);
-  // Prints function which checks target feature specific predicate.
-  void printCheckPredicate(const std::vector<Predicate> &Predicates,
-                           raw_ostream &OS);
 };
 
 void X86EVEX2VEXTablesEmitter::printTable(const std::vector<Entry> &Table,
@@ -77,19 +72,6 @@ void X86EVEX2VEXTablesEmitter::printTable(const std::vector<Entry> &Table,
   OS << "};\n\n";
 }
 
-void X86EVEX2VEXTablesEmitter::printCheckPredicate(
-    const std::vector<Predicate> &Predicates, raw_ostream &OS) {
-  OS << "static bool CheckVEXInstPredicate"
-     << "(MachineInstr &MI, const X86Subtarget *Subtarget) {\n"
-     << "  unsigned Opc = MI.getOpcode();\n"
-     << "  switch (Opc) {\n"
-     << "    default: return true;\n";
-  for (const auto &Pair : Predicates)
-    OS << "    case X86::" << Pair.first << ": return " << Pair.second << ";\n";
-  OS << "  }\n"
-     << "}\n\n";
-}
-
 // Return true if the 2 BitsInits are equal
 // Calculates the integer value residing BitsInit object
 static inline uint64_t getValueFromBitsInit(const BitsInit *B) {
@@ -164,18 +146,6 @@ class IsMatch {
 };
 
 void X86EVEX2VEXTablesEmitter::run(raw_ostream &OS) {
-  auto getPredicates = [&](const CodeGenInstruction *Inst) {
-    std::vector<Record *> PredicatesRecords =
-        Inst->TheDef->getValueAsListOfDefs("Predicates");
-    // Currently we only do AVX related checks and assume each instruction
-    // has one and only one AVX related predicates.
-    for (unsigned i = 0, e = PredicatesRecords.size(); i != e; ++i)
-      if (PredicatesRecords[i]->getName().starts_with("HasAVX"))
-        return PredicatesRecords[i]->getValueAsString("CondString");
-    llvm_unreachable(
-        "Instruction with checkPredicate set must have one predicate!");
-  };
-
   emitSourceFileHeader("X86 EVEX2VEX tables", OS);
 
   ArrayRef<const CodeGenInstruction *> NumberedInstructions =
@@ -228,18 +198,11 @@ void X86EVEX2VEXTablesEmitter::run(raw_ostream &OS) {
       EVEX2VEX256.push_back(std::make_pair(EVEXInst, VEXInst)); // {0,1}
     else
       EVEX2VEX128.push_back(std::make_pair(EVEXInst, VEXInst)); // {0,0}
-
-    // Adding predicate check to EVEX2VEXPredicates table when needed.
-    if (VEXInst->TheDef->getValueAsBit("checkVEXPredicate"))
-      EVEX2VEXPredicates.push_back(
-          std::make_pair(EVEXInst->TheDef->getName(), getPredicates(VEXInst)));
   }
 
   // Print both tables
   printTable(EVEX2VEX128, OS);
   printTable(EVEX2VEX256, OS);
-  // Print CheckVEXInstPredicate function.
-  printCheckPredicate(EVEX2VEXPredicates, OS);
 }
 } // namespace
 


        


More information about the llvm-commits mailing list