[llvm] [X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl (PR #79761)
Shengchen Kan via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 29 09:16:50 PST 2024
https://github.com/KanRobert updated https://github.com/llvm/llvm-project/pull/79761
From b11ae150cb905bdf163cf45ec01fd9c9394adb38 Mon Sep 17 00:00:00 2001
From: Shengchen Kan <shengchen.kan at intel.com>
Date: Sun, 28 Jan 2024 12:00:19 +0800
Subject: [PATCH] [X86][CodeGen] Support folding memory broadcast in
X86InstrInfo::foldMemoryOperandImpl
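Fold a broadcast load into the instruction that consumes it, producing the
EVEX embedded-broadcast memory form. ISel already forms embedded broadcasts
in the common case; this extends the same folding to the MachineInstr level,
e.g. when the broadcast is only materialized around register allocation.

A minimal C++ sketch of code that benefits (illustrative only, assuming
-mavx512f; not taken from the patch's tests):

    #include <immintrin.h>
    __m512i add_splat(__m512i V, const int *P) {
      // The splat may be emitted as a separate
      //   vpbroadcastd (%rdi), %zmm1
      // feeding the add; with this fold it can become the single
      //   vpaddd (%rdi){1to16}, %zmm0, %zmm0
      return _mm512_add_epi32(V, _mm512_set1_epi32(*P));
    }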
---
llvm/lib/Target/X86/X86InstrAVX512.td | 2 +-
llvm/lib/Target/X86/X86InstrFoldTables.cpp | 21 +++-
llvm/lib/Target/X86/X86InstrFoldTables.h | 6 +
llvm/lib/Target/X86/X86InstrInfo.cpp | 134 +++++++++++++++++++++
llvm/lib/Target/X86/X86InstrInfo.h | 7 ++
5 files changed, 167 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index bb5e22c7142793..b588f660e2744e 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
MaskInfo.RC:$src0))],
DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;
- let hasSideEffects = 0, mayLoad = 1 in
+ let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
(ins SrcInfo.ScalarMemOp:$src),
!strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
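The two new hints drive the rest of the patch: canFoldAsLoad marks the rm
(broadcast-from-memory) form as a load that foldMemoryOperandImpl may absorb
into a user, and isReMaterializable lets the register allocator re-emit the
broadcast instead of spilling its result. A rough sketch of how a caller
might consult these hints (an assumption: simplified and not compilable in
isolation; real callers go through TargetInstrInfo's folding hooks):

    // Hedged sketch, not the exact LLVM call sites:
    if (LoadMI.canFoldAsLoad() && TII.isTriviallyReMaterializable(LoadMI))
      tryFoldLoadIntoUser(MI, LoadMI); // hypothetical helper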
diff --git a/llvm/lib/Target/X86/X86InstrFoldTables.cpp b/llvm/lib/Target/X86/X86InstrFoldTables.cpp
index 1d6df0f6ad129f..cfea8acecc0260 100644
--- a/llvm/lib/Target/X86/X86InstrFoldTables.cpp
+++ b/llvm/lib/Target/X86/X86InstrFoldTables.cpp
@@ -143,6 +143,23 @@ const X86FoldTableEntry *llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
return lookupFoldTableImpl(FoldTable, RegOp);
}
+const X86FoldTableEntry *
+llvm::lookupBroadcastFoldTable(unsigned RegOp, unsigned OpNum) {
+ ArrayRef<X86FoldTableEntry> FoldTable;
+ if (OpNum == 1)
+ FoldTable = ArrayRef(BroadcastTable1);
+ else if (OpNum == 2)
+ FoldTable = ArrayRef(BroadcastTable2);
+ else if (OpNum == 3)
+ FoldTable = ArrayRef(BroadcastTable3);
+ else if (OpNum == 4)
+ FoldTable = ArrayRef(BroadcastTable4);
+ else
+ return nullptr;
+
+ return lookupFoldTableImpl(FoldTable, RegOp);
+}
+
namespace {
// This class stores the memory unfolding tables. It is instantiated as a
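For reference, lookupFoldTableImpl is a binary search keyed on the register
opcode over tables kept sorted by that key. A self-contained sketch of the
pattern (struct and names are illustrative, not LLVM's exact definitions):

    #include <algorithm>
    struct Entry { unsigned KeyOp, DstOp, Flags; };
    static const Entry *lookupImpl(const Entry *B, const Entry *E,
                                   unsigned Op) {
      // Tables are sorted by KeyOp, so std::lower_bound applies.
      auto *I = std::lower_bound(
          B, E, Op, [](const Entry &L, unsigned R) { return L.KeyOp < R; });
      return (I != E && I->KeyOp == Op) ? I : nullptr;
    }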
@@ -288,8 +305,8 @@ struct X86BroadcastFoldTable {
};
} // namespace
-static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
- unsigned BroadcastBits) {
+bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
+ unsigned BroadcastBits) {
switch (Entry.Flags & TB_BCAST_MASK) {
case TB_BCAST_W:
case TB_BCAST_SH:
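The remaining cases of this switch are cut off by the diff context; the
logic pairs integer and scalar-FP broadcast flags of the same width. A
self-contained sketch (enum values are illustrative, not LLVM's TB_BCAST_*
encodings):

    enum BcastKind { BCAST_W, BCAST_SH, BCAST_D, BCAST_SS, BCAST_Q, BCAST_SD };
    static bool matchesBits(BcastKind K, unsigned Bits) {
      switch (K) {
      case BCAST_W: case BCAST_SH: return Bits == 16; // word / half float
      case BCAST_D: case BCAST_SS: return Bits == 32; // dword / single
      case BCAST_Q: case BCAST_SD: return Bits == 64; // qword / double
      }
      return false;
    }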
diff --git a/llvm/lib/Target/X86/X86InstrFoldTables.h b/llvm/lib/Target/X86/X86InstrFoldTables.h
index a27d868537cbcf..9c5dea48d22733 100644
--- a/llvm/lib/Target/X86/X86InstrFoldTables.h
+++ b/llvm/lib/Target/X86/X86InstrFoldTables.h
@@ -44,6 +44,11 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
// operand OpNum.
const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);
+// Look up the broadcast folding table entry for folding a broadcast load
+// into operand OpNum.
+const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
+ unsigned OpNum);
+
// Look up the memory unfolding table entry for this instruction.
const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
@@ -52,6 +57,7 @@ const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
unsigned BroadcastBits);
+// Return true if Entry's broadcast flags match the width BroadcastBits.
+bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
} // namespace llvm
#endif
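A hypothetical use of the two new entry points together (the opcode query,
operand number, broadcast width, and rewrite helper are all illustrative):

    if (const X86FoldTableEntry *E =
            lookupBroadcastFoldTable(MI.getOpcode(), /*OpNum=*/2))
      if (matchBroadcastSize(*E, /*BroadcastBits=*/32))
        // E->DstOp is the embedded-broadcast memory form of MI's opcode.
        rewriteToMemForm(MI, E->DstOp); // hypothetical helper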
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 1adaf097bcbf73..9ef680f6768782 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
case X86::MMX_MOVD64rm:
case X86::MMX_MOVQ64rm:
// AVX-512
+ case X86::VPBROADCASTBZ128rm:
+ case X86::VPBROADCASTBZ256rm:
+ case X86::VPBROADCASTBZrm:
+ case X86::VBROADCASTF32X2Z256rm:
+ case X86::VBROADCASTF32X2Zrm:
+ case X86::VBROADCASTI32X2Z128rm:
+ case X86::VBROADCASTI32X2Z256rm:
+ case X86::VBROADCASTI32X2Zrm:
+ case X86::VPBROADCASTWZ128rm:
+ case X86::VPBROADCASTWZ256rm:
+ case X86::VPBROADCASTWZrm:
+ case X86::VPBROADCASTDZ128rm:
+ case X86::VPBROADCASTDZ256rm:
+ case X86::VPBROADCASTDZrm:
+ case X86::VBROADCASTSSZ128rm:
+ case X86::VBROADCASTSSZ256rm:
+ case X86::VBROADCASTSSZrm:
+ case X86::VPBROADCASTQZ128rm:
+ case X86::VPBROADCASTQZ256rm:
+ case X86::VPBROADCASTQZrm:
+ case X86::VBROADCASTSDZ256rm:
+ case X86::VBROADCASTSDZrm:
case X86::VMOVSSZrm:
case X86::VMOVSSZrm_alt:
case X86::VMOVSDZrm:
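Re-executing one of these broadcasts reproduces the value exactly, so they
are safe to rematerialize once the generic checks (dereferenceable invariant
load, no side effects) pass. A hypothetical case where that pays off,
depending on register pressure (illustrative only):

    #include <immintrin.h>
    __m512i use_twice(__m512i A, __m512i B, const int *P) {
      // If the splat's register would otherwise be spilled between its two
      // uses, the allocator can re-issue vpbroadcastd instead of emitting a
      // spill/reload pair (assuming the load from P is invariant).
      __m512i S = _mm512_set1_epi32(*P);
      return _mm512_add_epi32(_mm512_mullo_epi32(A, S),
                              _mm512_sub_epi32(B, S));
    }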
@@ -8063,6 +8085,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
MOs.push_back(MachineOperand::CreateReg(0, false));
break;
}
+ case X86::VPBROADCASTBZ128rm:
+ case X86::VPBROADCASTBZ256rm:
+ case X86::VPBROADCASTBZrm:
+ case X86::VBROADCASTF32X2Z256rm:
+ case X86::VBROADCASTF32X2Zrm:
+ case X86::VBROADCASTI32X2Z128rm:
+ case X86::VBROADCASTI32X2Z256rm:
+ case X86::VBROADCASTI32X2Zrm:
+ // No instructions currently fuse with 8-bit or 32-bit x 2 broadcasts.
+ return nullptr;
+
+#define FOLD_BROADCAST(SIZE) \
+ MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands, \
+ LoadMI.operands_begin() + NumOps); \
+ return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*BitsSize=*/SIZE, \
+ Alignment, /*AllowCommute=*/true);
+ case X86::VPBROADCASTWZ128rm:
+ case X86::VPBROADCASTWZ256rm:
+ case X86::VPBROADCASTWZrm:
+ FOLD_BROADCAST(16);
+ case X86::VPBROADCASTDZ128rm:
+ case X86::VPBROADCASTDZ256rm:
+ case X86::VPBROADCASTDZrm:
+ case X86::VBROADCASTSSZ128rm:
+ case X86::VBROADCASTSSZ256rm:
+ case X86::VBROADCASTSSZrm:
+ FOLD_BROADCAST(32);
+ case X86::VPBROADCASTQZ128rm:
+ case X86::VPBROADCASTQZ256rm:
+ case X86::VPBROADCASTQZrm:
+ case X86::VBROADCASTSDZ256rm:
+ case X86::VBROADCASTSDZrm:
+ FOLD_BROADCAST(64);
default: {
if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
return nullptr;
@@ -8077,6 +8132,78 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
/*Size=*/0, Alignment, /*AllowCommute=*/true);
}
+MachineInstr *X86InstrInfo::foldMemoryBroadcast(
+ MachineFunction &MF, MachineInstr &MI, unsigned OpNum,
+ ArrayRef<MachineOperand> MOs, MachineBasicBlock::iterator InsertPt,
+ unsigned BitsSize, Align Alignment, bool AllowCommute) const {
+
+ if (auto *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum))
+ return matchBroadcastSize(*I, BitsSize)
+ ? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
+ : nullptr;
+
+ // TODO: Share code with foldMemoryOperandImpl for the commute
+ if (AllowCommute) {
+ unsigned CommuteOpIdx1 = OpNum, CommuteOpIdx2 = CommuteAnyOperandIndex;
+ if (findCommutedOpIndices(MI, CommuteOpIdx1, CommuteOpIdx2)) {
+ bool HasDef = MI.getDesc().getNumDefs();
+ Register Reg0 = HasDef ? MI.getOperand(0).getReg() : Register();
+ Register Reg1 = MI.getOperand(CommuteOpIdx1).getReg();
+ Register Reg2 = MI.getOperand(CommuteOpIdx2).getReg();
+ bool Tied1 =
+ 0 == MI.getDesc().getOperandConstraint(CommuteOpIdx1, MCOI::TIED_TO);
+ bool Tied2 =
+ 0 == MI.getDesc().getOperandConstraint(CommuteOpIdx2, MCOI::TIED_TO);
+
+ // If either of the commutable operands is tied to the destination
+ // then we cannot commute + fold.
+ if ((HasDef && Reg0 == Reg1 && Tied1) ||
+ (HasDef && Reg0 == Reg2 && Tied2))
+ return nullptr;
+
+ MachineInstr *CommutedMI =
+ commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+ if (!CommutedMI) {
+ // Unable to commute.
+ return nullptr;
+ }
+ if (CommutedMI != &MI) {
+ // New instruction. We can't fold from this.
+ CommutedMI->eraseFromParent();
+ return nullptr;
+ }
+
+ // Attempt to fold with the commuted version of the instruction.
+ MachineInstr *NewMI = foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs,
+ InsertPt, BitsSize, Alignment,
+ /*AllowCommute=*/false);
+ if (NewMI)
+ return NewMI;
+
+ // Folding failed again - undo the commute before returning.
+ MachineInstr *UncommutedMI =
+ commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+ if (!UncommutedMI) {
+ // Unable to commute.
+ return nullptr;
+ }
+ if (UncommutedMI != &MI) {
+ // New instruction. It doesn't need to be kept.
+ UncommutedMI->eraseFromParent();
+ return nullptr;
+ }
+
+ // Return here to prevent a duplicate fuse failure report.
+ return nullptr;
+ }
+ }
+
+ // No fusion
+ if (PrintFailedFusing && !MI.isCopy())
+ dbgs() << "We failed to fuse operand " << OpNum << " in " << MI;
+ return nullptr;
+}
+
static SmallVector<MachineMemOperand *, 2>
extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
SmallVector<MachineMemOperand *, 2> LoadMMOs;
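The commute path above mirrors foldMemoryOperandImpl (hence the TODO about
sharing code): try the requested operand; if the instruction is commutative,
swap the operands and retry exactly once, undoing the swap when the second
attempt also fails. A toy analogue of that control flow (names and types are
illustrative):

    #include <utility>
    template <typename TryFold>
    bool foldWithCommute(int &A, int &B, TryFold Fold) {
      if (Fold(A))      // try the requested operand first
        return true;
      std::swap(A, B);  // "commute" the instruction
      if (Fold(A))      // retry with the operands exchanged
        return true;
      std::swap(A, B);  // fold failed again: undo the commute
      return false;
    }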
@@ -8167,6 +8294,10 @@ bool X86InstrInfo::unfoldMemoryOperand(
unsigned Index = I->Flags & TB_INDEX_MASK;
bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
bool FoldedStore = I->Flags & TB_FOLDED_STORE;
+ // FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
+ if ((I->Flags & TB_BCAST_MASK) == TB_BCAST_SH)
+ return false;
+
if (UnfoldLoad && !FoldedLoad)
return false;
UnfoldLoad &= FoldedLoad;
@@ -8316,6 +8447,9 @@ bool X86InstrInfo::unfoldMemoryOperand(
unsigned Index = I->Flags & TB_INDEX_MASK;
bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
bool FoldedStore = I->Flags & TB_FOLDED_STORE;
+ // FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
+ if ((I->Flags & TB_BCAST_MASK) == TB_BCAST_SH)
+ return false;
const MCInstrDesc &MCID = get(Opc);
MachineFunction &MF = DAG.getMachineFunction();
const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
index 0cb69050656109..3a1f98a005ca3a 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -643,6 +643,13 @@ class X86InstrInfo final : public X86GenInstrInfo {
MachineBasicBlock::iterator InsertPt,
unsigned Size, Align Alignment) const;
+ MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
+ unsigned OpNum,
+ ArrayRef<MachineOperand> MOs,
+ MachineBasicBlock::iterator InsertPt,
+ unsigned BitsSize, Align Alignment,
+ bool AllowCommute) const;
+
/// isFrameOperand - Return true and the FrameIndex if the specified
/// operand and follow operands form a reference to the stack frame.
bool isFrameOperand(const MachineInstr &MI, unsigned int Op,