[llvm] [WIP][X86][tablgen] Auto-gen broadcast tables (PR #73654)

Shengchen Kan via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 28 06:57:42 PST 2023


https://github.com/KanRobert created https://github.com/llvm/llvm-project/pull/73654

None

>From d412c141417f1d0434b4e2627aca0d82e6255106 Mon Sep 17 00:00:00 2001
From: Shengchen Kan <shengchen.kan at intel.com>
Date: Tue, 28 Nov 2023 22:56:08 +0800
Subject: [PATCH] [WIP][X86][tablgen] Auto-gen broadcast tables

---
 llvm/lib/Target/X86/X86InstrFoldTables.cpp   |  3 +-
 llvm/utils/TableGen/X86FoldTablesEmitter.cpp | 80 +++++++++++++++-----
 2 files changed, 63 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Target/X86/X86InstrFoldTables.cpp b/llvm/lib/Target/X86/X86InstrFoldTables.cpp
index 7a3611b90da9895..cca2d03f1356ea9 100644
--- a/llvm/lib/Target/X86/X86InstrFoldTables.cpp
+++ b/llvm/lib/Target/X86/X86InstrFoldTables.cpp
@@ -23,6 +23,7 @@ using namespace llvm;
 // are currently emitted in X86GenInstrInfo.inc in alphabetical order. Which
 // makes sorting these tables a simple matter of alphabetizing the table.
 #include "X86GenFoldTables.inc"
+/*
 static const X86FoldTableEntry BroadcastTable2[] = {
   { X86::VADDPDZ128rr,   X86::VADDPDZ128rmb,   TB_BCAST_SD },
   { X86::VADDPDZ256rr,   X86::VADDPDZ256rmb,   TB_BCAST_SD },
@@ -316,7 +317,7 @@ static const X86FoldTableEntry BroadcastTable3[] = {
   { X86::VPTERNLOGQZ256rri,    X86::VPTERNLOGQZ256rmbi,   TB_BCAST_Q },
   { X86::VPTERNLOGQZrri,       X86::VPTERNLOGQZrmbi,      TB_BCAST_Q },
 };
-
+*/
 // Table to map instructions safe to broadcast using a different width from the
 // element width.
 static const X86FoldTableEntry BroadcastSizeTable2[] = {
diff --git a/llvm/utils/TableGen/X86FoldTablesEmitter.cpp b/llvm/utils/TableGen/X86FoldTablesEmitter.cpp
index adcf67e8c3cc538..010240b18adac6a 100644
--- a/llvm/utils/TableGen/X86FoldTablesEmitter.cpp
+++ b/llvm/utils/TableGen/X86FoldTablesEmitter.cpp
@@ -153,6 +153,11 @@ class X86FoldTablesEmitter {
   FoldTable Table2;
   FoldTable Table3;
   FoldTable Table4;
+  FoldTable BroadcastTable0;
+  FoldTable BroadcastTable1;
+  FoldTable BroadcastTable2;
+  FoldTable BroadcastTable3;
+  FoldTable BroadcastTable4;
 
 public:
   X86FoldTablesEmitter(RecordKeeper &R) : Records(R), Target(R) {}
@@ -165,7 +170,7 @@ class X86FoldTablesEmitter {
   // S sets the strategy of adding the TB_NO_REVERSE flag.
   void updateTables(const CodeGenInstruction *RegInst,
                     const CodeGenInstruction *MemInst, uint16_t S = 0,
-                    bool IsManual = false);
+                    bool IsManual = false, bool IsBroadcast = false);
 
   // Generates X86FoldTableEntry with the given instructions and fill it with
   // the appropriate flags - then adds it to Table.
@@ -288,11 +293,12 @@ static bool isNOREXRegClass(const Record *Op) {
 class IsMatch {
   const CodeGenInstruction *MemInst;
   const X86Disassembler::RecognizableInstrBase MemRI;
+  bool IsBroadcast;
   const unsigned Variant;
 
 public:
-  IsMatch(const CodeGenInstruction *Inst, unsigned V)
-      : MemInst(Inst), MemRI(*MemInst), Variant(V) {}
+  IsMatch(const CodeGenInstruction *Inst, bool IsBroadcast, unsigned V)
+      : MemInst(Inst), MemRI(*MemInst), IsBroadcast(IsBroadcast), Variant(V) {}
 
   bool operator()(const CodeGenInstruction *RegInst) {
     X86Disassembler::RecognizableInstrBase RegRI(*RegInst);
@@ -300,8 +306,13 @@ class IsMatch {
     const Record *MemRec = MemInst->TheDef;
 
     // EVEX_B means different things for memory and register forms.
-    if (RegRI.HasEVEX_B || MemRI.HasEVEX_B)
-      return false;
+    if (IsBroadcast) {
+      if (RegRI.HasEVEX_B || !MemRI.HasEVEX_B)
+        return false;
+    } else {
+      if (RegRI.HasEVEX_B || MemRI.HasEVEX_B)
+        return false;
+    }
 
     if (!mayFoldFromLeftToRight(RegRI.Form, MemRI.Form))
       return false;
@@ -474,7 +485,8 @@ void X86FoldTablesEmitter::addEntryWithFlags(FoldTable &Table,
 
 void X86FoldTablesEmitter::updateTables(const CodeGenInstruction *RegInst,
                                         const CodeGenInstruction *MemInst,
-                                        uint16_t S, bool IsManual) {
+                                        uint16_t S, bool IsManual,
+                                        bool IsBroadcast) {
 
   Record *RegRec = RegInst->TheDef;
   Record *MemRec = MemInst->TheDef;
@@ -505,19 +517,24 @@ void X86FoldTablesEmitter::updateTables(const CodeGenInstruction *RegInst,
           isMemoryOperand(MemOpRec)) {
         switch (I) {
         case 0:
-          addEntryWithFlags(Table0, RegInst, MemInst, S, 0, IsManual);
+          addEntryWithFlags(IsBroadcast ? BroadcastTable0 : Table0, RegInst,
+                            MemInst, S, 0, IsManual);
           return;
         case 1:
-          addEntryWithFlags(Table1, RegInst, MemInst, S, 1, IsManual);
+          addEntryWithFlags(IsBroadcast ? BroadcastTable1 : Table1, RegInst,
+                            MemInst, S, 1, IsManual);
           return;
         case 2:
-          addEntryWithFlags(Table2, RegInst, MemInst, S, 2, IsManual);
+          addEntryWithFlags(IsBroadcast ? BroadcastTable2 : Table2, RegInst,
+                            MemInst, S, 2, IsManual);
           return;
         case 3:
-          addEntryWithFlags(Table3, RegInst, MemInst, S, 3, IsManual);
+          addEntryWithFlags(IsBroadcast ? BroadcastTable3 : Table3, RegInst,
+                            MemInst, S, 3, IsManual);
           return;
         case 4:
-          addEntryWithFlags(Table4, RegInst, MemInst, S, 4, IsManual);
+          addEntryWithFlags(IsBroadcast ? BroadcastTable4 : Table4, RegInst,
+                            MemInst, S, 4, IsManual);
           return;
         }
       }
@@ -580,8 +597,19 @@ void X86FoldTablesEmitter::run(raw_ostream &O) {
     }
   }
 
+  // Create a copy b/c the register instruction will be removed when a new entry
+  // is added into memory fold tables.
+  auto RegInstsForBroadcast = RegInsts;
+
   Record *AsmWriter = Target.getAsmWriter();
   unsigned Variant = AsmWriter->getValueAsInt("Variant");
+  auto FixUp = [&](const CodeGenInstruction *RegInst) {
+    StringRef RegInstName = RegInst->TheDef->getName();
+    if (RegInstName.ends_with("_REV") || RegInstName.ends_with("_alt"))
+      if (auto *RegAltRec = Records.getDef(RegInstName.drop_back(4)))
+        RegInst = &Target.getInstruction(RegAltRec);
+    return RegInst;
+  };
   // For each memory form instruction, try to find its register form
   // instruction.
   for (const CodeGenInstruction *MemInst : MemInsts) {
@@ -596,17 +624,26 @@ void X86FoldTablesEmitter::run(raw_ostream &O) {
     // opcode.
     std::vector<const CodeGenInstruction *> &OpcRegInsts = RegInstsIt->second;
 
-    auto Match = find_if(OpcRegInsts, IsMatch(MemInst, Variant));
+    // Memory fold tables
+    auto Match = find_if(OpcRegInsts, IsMatch(MemInst, /*IsBroadcast=*/false, Variant));
     if (Match != OpcRegInsts.end()) {
-      const CodeGenInstruction *RegInst = *Match;
-      StringRef RegInstName = RegInst->TheDef->getName();
-      if (RegInstName.ends_with("_REV") || RegInstName.ends_with("_alt"))
-        if (auto *RegAltRec = Records.getDef(RegInstName.drop_back(4)))
-          RegInst = &Target.getInstruction(RegAltRec);
-
-      updateTables(RegInst, MemInst);
+      updateTables(FixUp(*Match), MemInst);
       OpcRegInsts.erase(Match);
     }
+
+    // Broadcast tables
+    StringRef MemInstName = MemInst->TheDef->getName();
+    if (MemInstName.contains("mb"))
+      continue;
+    RegInstsIt = RegInstsForBroadcast.find(Opc);
+    assert(RegInstsIt != RegInstsForBroadcast.end() && "Unexpected control flow");
+    std::vector<const CodeGenInstruction *> & OpcRegInstsForBroadcast = RegInstsIt->second;
+    Match = find_if(OpcRegInstsForBroadcast, IsMatch(MemInst, /*IsBroadcast=*/true, Variant));
+    if (Match != OpcRegInstsForBroadcast.end()) {
+      updateTables(FixUp(*Match), MemInst, 0, /*IsManual=*/false,
+                   /*IsBroadcast=*/true);
+      OpcRegInstsForBroadcast.erase(Match);
+    }
   }
 
   // Add the manually mapped instructions listed above.
@@ -640,6 +677,11 @@ void X86FoldTablesEmitter::run(raw_ostream &O) {
   PRINT_TABLE(Table2)
   PRINT_TABLE(Table3)
   PRINT_TABLE(Table4)
+  //PRINT_TABLE(BroadcastTable0)
+  //PRINT_TABLE(BroadcastTable1)
+  PRINT_TABLE(BroadcastTable2)
+  PRINT_TABLE(BroadcastTable3)
+  //PRINT_TABLE(BroadcastTable4)
 }
 
 static TableGen::Emitter::OptClass<X86FoldTablesEmitter>



More information about the llvm-commits mailing list