[llvm] [X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl (PR #79761)

Shengchen Kan via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 29 09:16:50 PST 2024


https://github.com/KanRobert updated https://github.com/llvm/llvm-project/pull/79761

>From b11ae150cb905bdf163cf45ec01fd9c9394adb38 Mon Sep 17 00:00:00 2001
From: Shengchen Kan <shengchen.kan at intel.com>
Date: Sun, 28 Jan 2024 12:00:19 +0800
Subject: [PATCH] [X86][CodeGen] Support folding memory broadcast in
 X86InstrInfo::foldMemoryOperandImpl

---
 llvm/lib/Target/X86/X86InstrAVX512.td      |   2 +-
 llvm/lib/Target/X86/X86InstrFoldTables.cpp |  21 +++-
 llvm/lib/Target/X86/X86InstrFoldTables.h   |   6 +
 llvm/lib/Target/X86/X86InstrInfo.cpp       | 134 +++++++++++++++++++++
 llvm/lib/Target/X86/X86InstrInfo.h         |   7 ++
 5 files changed, 167 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index bb5e22c7142793..b588f660e2744e 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
                         MaskInfo.RC:$src0))],
                       DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;
 
-  let hasSideEffects = 0, mayLoad = 1 in
+  let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
   def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
                     (ins SrcInfo.ScalarMemOp:$src),
                     !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
diff --git a/llvm/lib/Target/X86/X86InstrFoldTables.cpp b/llvm/lib/Target/X86/X86InstrFoldTables.cpp
index 1d6df0f6ad129f..cfea8acecc0260 100644
--- a/llvm/lib/Target/X86/X86InstrFoldTables.cpp
+++ b/llvm/lib/Target/X86/X86InstrFoldTables.cpp
@@ -143,6 +143,23 @@ const X86FoldTableEntry *llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
   return lookupFoldTableImpl(FoldTable, RegOp);
 }
 
+const X86FoldTableEntry *
+llvm::lookupBroadcastFoldTable(unsigned RegOp, unsigned OpNum) {
+  // Each foldable operand index (1-4) has its own broadcast table.
+  switch (OpNum) {
+  case 1:
+    return lookupFoldTableImpl(ArrayRef(BroadcastTable1), RegOp);
+  case 2:
+    return lookupFoldTableImpl(ArrayRef(BroadcastTable2), RegOp);
+  case 3:
+    return lookupFoldTableImpl(ArrayRef(BroadcastTable3), RegOp);
+  case 4:
+    return lookupFoldTableImpl(ArrayRef(BroadcastTable4), RegOp);
+  default:
+    return nullptr;
+  }
+}
+
 namespace {
 
 // This class stores the memory unfolding tables. It is instantiated as a
@@ -288,8 +305,8 @@ struct X86BroadcastFoldTable {
 };
 } // namespace
 
-static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
-                               unsigned BroadcastBits) {
+bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
+                              unsigned BroadcastBits) {
   switch (Entry.Flags & TB_BCAST_MASK) {
   case TB_BCAST_W:
   case TB_BCAST_SH:
diff --git a/llvm/lib/Target/X86/X86InstrFoldTables.h b/llvm/lib/Target/X86/X86InstrFoldTables.h
index a27d868537cbcf..9c5dea48d22733 100644
--- a/llvm/lib/Target/X86/X86InstrFoldTables.h
+++ b/llvm/lib/Target/X86/X86InstrFoldTables.h
@@ -44,6 +44,11 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
 // operand OpNum.
 const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);
 
+// Look up the broadcast folding table entry for folding a broadcast with
+// operand OpNum.
+const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
+                                                  unsigned OpNum);
+
 // Look up the memory unfolding table entry for this instruction.
 const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
 
@@ -52,6 +57,7 @@ const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
 const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
                                                         unsigned BroadcastBits);
 
+bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 1adaf097bcbf73..9ef680f6768782 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
   case X86::MMX_MOVD64rm:
   case X86::MMX_MOVQ64rm:
   // AVX-512
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
   case X86::VMOVSSZrm:
   case X86::VMOVSSZrm_alt:
   case X86::VMOVSDZrm:
@@ -8063,6 +8085,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
     MOs.push_back(MachineOperand::CreateReg(0, false));
     break;
   }
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+    // No instructions currently fuse with 8bits or 32bits x 2.
+    return nullptr;
+
+#define FOLD_BROADCAST(SIZE)                                                   \
+  MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,          \
+             LoadMI.operands_begin() + NumOps);                                \
+  return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE,     \
+                             Alignment, /*AllowCommute=*/true);
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+    FOLD_BROADCAST(16);
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+    FOLD_BROADCAST(32);
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
+    FOLD_BROADCAST(64);
   default: {
     if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
       return nullptr;
@@ -8077,6 +8132,78 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
                                /*Size=*/0, Alignment, /*AllowCommute=*/true);
 }
 
+MachineInstr *X86InstrInfo::foldMemoryBroadcast(
+    MachineFunction &MF, MachineInstr &MI, unsigned OpNum,
+    ArrayRef<MachineOperand> MOs, MachineBasicBlock::iterator InsertPt,
+    unsigned BitsSize, Align Alignment, bool AllowCommute) const {
+  // Fold the BitsSize-bit broadcast at address MOs into operand OpNum of MI.
+  if (auto *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum))
+    return matchBroadcastSize(*I, BitsSize)
+               ? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
+               : nullptr; // Size mismatch: give up without commuting.
+
+  // TODO: Share this commute-and-retry logic with foldMemoryOperandImpl.
+  if (AllowCommute) {
+    unsigned CommuteOpIdx1 = OpNum, CommuteOpIdx2 = CommuteAnyOperandIndex;
+    if (findCommutedOpIndices(MI, CommuteOpIdx1, CommuteOpIdx2)) {
+      bool HasDef = MI.getDesc().getNumDefs();
+      Register Reg0 = HasDef ? MI.getOperand(0).getReg() : Register();
+      Register Reg1 = MI.getOperand(CommuteOpIdx1).getReg();
+      Register Reg2 = MI.getOperand(CommuteOpIdx2).getReg();
+      bool Tied1 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx1, MCOI::TIED_TO);
+      bool Tied2 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx2, MCOI::TIED_TO);
+
+      // If either of the commutable operands is tied to the destination
+      // then we cannot commute + fold.
+      if ((HasDef && Reg0 == Reg1 && Tied1) ||
+          (HasDef && Reg0 == Reg2 && Tied2))
+        return nullptr;
+
+      MachineInstr *CommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!CommutedMI) {
+        // Unable to commute.
+        return nullptr;
+      }
+      if (CommutedMI != &MI) {
+        // Commuting produced a new instruction; we can't fold from it.
+        CommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Retry the fold on the commuted instruction, without commuting again.
+      MachineInstr *NewMI = foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs,
+                                                InsertPt, BitsSize, Alignment,
+                                                /*AllowCommute=*/false);
+      if (NewMI)
+        return NewMI;
+
+      // Folding failed again - undo the commute before returning.
+      MachineInstr *UncommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!UncommutedMI) {
+        // Unable to commute back.
+        return nullptr;
+      }
+      if (UncommutedMI != &MI) {
+        // The undo produced a new instruction; it doesn't need to be kept.
+        UncommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Return here so the fuse failure below is not reported a second time.
+      return nullptr;
+    }
+  }
+
+  // No fusion; report the failure when PrintFailedFusing is set.
+  if (PrintFailedFusing && !MI.isCopy())
+    dbgs() << "We failed to fuse operand " << OpNum << " in " << MI;
+  return nullptr;
+}
+
 static SmallVector<MachineMemOperand *, 2>
 extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
   SmallVector<MachineMemOperand *, 2> LoadMMOs;
@@ -8167,6 +8294,10 @@ bool X86InstrInfo::unfoldMemoryOperand(
   unsigned Index = I->Flags & TB_INDEX_MASK;
   bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
   bool FoldedStore = I->Flags & TB_FOLDED_STORE;
+  // FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
+  if ((I->Flags & TB_BCAST_MASK) == TB_BCAST_SH)
+    return false;
+
   if (UnfoldLoad && !FoldedLoad)
     return false;
   UnfoldLoad &= FoldedLoad;
@@ -8316,6 +8447,9 @@ bool X86InstrInfo::unfoldMemoryOperand(
   unsigned Index = I->Flags & TB_INDEX_MASK;
   bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
   bool FoldedStore = I->Flags & TB_FOLDED_STORE;
+  // FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
+  if ((I->Flags & TB_BCAST_MASK) == TB_BCAST_SH)
+    return false;
   const MCInstrDesc &MCID = get(Opc);
   MachineFunction &MF = DAG.getMachineFunction();
   const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
index 0cb69050656109..3a1f98a005ca3a 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -643,6 +643,13 @@ class X86InstrInfo final : public X86GenInstrInfo {
                                         MachineBasicBlock::iterator InsertPt,
                                         unsigned Size, Align Alignment) const;
 
+  MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
+                                    unsigned OpNum,
+                                    ArrayRef<MachineOperand> MOs,
+                                    MachineBasicBlock::iterator InsertPt,
+                                    unsigned BitsSize, Align Alignment,
+                                    bool AllowCommute) const;
+
   /// isFrameOperand - Return true and the FrameIndex if the specified
   /// operand and follow operands form a reference to the stack frame.
   bool isFrameOperand(const MachineInstr &MI, unsigned int Op,



More information about the llvm-commits mailing list