[llvm] 6ac4fe8 - [X86] X86FixupVectorConstants.cpp - refactor constant search loop to take array of sorted candidates
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 1 08:07:00 PST 2024
Author: Simon Pilgrim
Date: 2024-02-01T16:06:36Z
New Revision: 6ac4fe8de014336ce66d02ddd07e85db3b8e77a2
URL: https://github.com/llvm/llvm-project/commit/6ac4fe8de014336ce66d02ddd07e85db3b8e77a2
DIFF: https://github.com/llvm/llvm-project/commit/6ac4fe8de014336ce66d02ddd07e85db3b8e77a2.diff
LOG: [X86] X86FixupVectorConstants.cpp - refactor constant search loop to take array of sorted candidates
Pulled out of #79815 - refactors the internal FixupConstant logic to accept an array of vzload/broadcast candidates, pre-sorted in ascending order of constant pool size
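For readers skimming the diff: below is a minimal, self-contained sketch (toy code, not the LLVM sources) of the table-driven shape this patch gives FixupConstant - a candidate table pre-sorted by constant size, walked in ascending order, where the first viable rebuild wins. The Candidate struct, the rebuildSplat stand-in and the printf harness are hypothetical; only the sorted-table / first-hit-wins structure mirrors the patch.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <functional>
#include <vector>

struct Candidate {
  const char *NewOpcode; // nullptr = opcode unavailable (e.g. needs AVX2/BWI)
  int NumCstElts;        // element count of the rebuilt constant
  int BitWidth;          // bit width of each element
  std::function<bool(int OrigBits, int EltBits)> Rebuild;
};

// Toy stand-in for rebuildSplatCst: succeeds when the element width divides
// the original constant width (the real code inspects the actual bit pattern).
static bool rebuildSplat(int OrigBits, int EltBits) {
  return OrigBits % EltBits == 0;
}

static bool fixupConstant(const std::vector<Candidate> &Fixups, int OrigBits) {
  // Mirrors the patch's EXPENSIVE_CHECKS invariant: candidates must be
  // pre-sorted by total constant size, NumCstElts * BitWidth.
  assert(std::is_sorted(Fixups.begin(), Fixups.end(),
                        [](const Candidate &A, const Candidate &B) {
                          return A.NumCstElts * A.BitWidth <
                                 B.NumCstElts * B.BitWidth;
                        }) &&
         "Constant fixup table not sorted in ascending constant size");
  for (const Candidate &Fixup : Fixups) {
    if (!Fixup.NewOpcode)
      continue; // entry disabled for this subtarget
    if (Fixup.Rebuild(OrigBits, Fixup.BitWidth)) {
      // First (i.e. smallest) viable candidate wins.
      std::printf("rewrite to %s (%d x %d bits)\n", Fixup.NewOpcode,
                  Fixup.NumCstElts, Fixup.BitWidth);
      return true;
    }
  }
  return false;
}

int main() {
  // Mirrors the VMOVAPSrm table shape: 32-bit candidates before 64-bit ones.
  std::vector<Candidate> Fixups = {
      {"VBROADCASTSSrm", 1, 32, rebuildSplat},
      {"VMOVDDUPrm", 1, 64, rebuildSplat},
  };
  fixupConstant(Fixups, /*OrigBits=*/128); // selects the 32-bit broadcast
}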
Added:
Modified:
llvm/lib/Target/X86/X86FixupVectorConstants.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
index 037a745d632fb..be3c4f0b1564c 100644
--- a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
+++ b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
@@ -216,8 +216,8 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
-static Constant *rebuildSplatableConstant(const Constant *C,
- unsigned SplatBitWidth) {
+static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
+ unsigned SplatBitWidth) {
std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
if (!Splat)
return nullptr;
@@ -238,8 +238,8 @@ static Constant *rebuildSplatableConstant(const Constant *C,
return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
}
-static Constant *rebuildZeroUpperConstant(const Constant *C,
- unsigned ScalarBitWidth) {
+static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
+ unsigned ScalarBitWidth) {
Type *Ty = C->getType();
Type *SclTy = Ty->getScalarType();
unsigned NumBits = Ty->getPrimitiveSizeInBits();
@@ -265,8 +265,6 @@ static Constant *rebuildZeroUpperConstant(const Constant *C,
return nullptr;
}
-typedef std::function<Constant *(const Constant *, unsigned)> RebuildFn;
-
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
MachineBasicBlock &MBB,
MachineInstr &MI) {
@@ -277,43 +275,42 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
bool HasBWI = ST->hasBWI();
bool HasVLX = ST->hasVLX();
- auto FixupConstant =
- [&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
- unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
- unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
- assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
- "Unexpected number of operands!");
-
- if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
- // Attempt to detect a suitable splat/vzload from increasing constant
- // bitwidths.
- // Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
- std::tuple<unsigned, unsigned, RebuildFn> FixupLoad[] = {
- {8, OpBcst8, rebuildSplatableConstant},
- {16, OpBcst16, rebuildSplatableConstant},
- {32, OpUpper32, rebuildZeroUpperConstant},
- {32, OpBcst32, rebuildSplatableConstant},
- {64, OpUpper64, rebuildZeroUpperConstant},
- {64, OpBcst64, rebuildSplatableConstant},
- {128, OpBcst128, rebuildSplatableConstant},
- {256, OpBcst256, rebuildSplatableConstant},
- };
- for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
- if (Op) {
- // Construct a suitable constant and adjust the MI to use the new
- // constant pool entry.
- if (Constant *NewCst = RebuildConstant(C, BitWidth)) {
- unsigned NewCPI =
- CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
- MI.setDesc(TII->get(Op));
- MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
- return true;
- }
- }
+ struct FixupEntry {
+ int Op;
+ int NumCstElts;
+ int BitWidth;
+ std::function<Constant *(const Constant *, unsigned, unsigned)>
+ RebuildConstant;
+ };
+ auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
+#ifdef EXPENSIVE_CHECKS
+ assert(llvm::is_sorted(Fixups,
+ [](const FixupEntry &A, const FixupEntry &B) {
+ return (A.NumCstElts * A.BitWidth) <
+ (B.NumCstElts * B.BitWidth);
+ }) &&
+ "Constant fixup table not sorted in ascending constant size");
+#endif
+ assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
+ "Unexpected number of operands!");
+ if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
+ for (const FixupEntry &Fixup : Fixups) {
+ if (Fixup.Op) {
+ // Construct a suitable constant and adjust the MI to use the new
+ // constant pool entry.
+ if (Constant *NewCst =
+ Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth)) {
+ unsigned NewCPI =
+ CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
+ MI.setDesc(TII->get(Fixup.Op));
+ MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
+ return true;
}
}
- return false;
- };
+ }
+ }
+ return false;
+ };
// Attempt to convert full width vector loads into broadcast/vzload loads.
switch (Opc) {
@@ -323,82 +320,125 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
case X86::MOVUPDrm:
case X86::MOVUPSrm:
// TODO: SSE3 MOVDDUP Handling
- return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVSDrm, X86::MOVSSrm, 1);
+ return FixupConstant({{X86::MOVSSrm, 1, 32, rebuildZeroUpperCst},
+ {X86::MOVSDrm, 1, 64, rebuildZeroUpperCst}},
+ 1);
case X86::VMOVAPDrm:
case X86::VMOVAPSrm:
case X86::VMOVUPDrm:
case X86::VMOVUPSrm:
- return FixupConstant(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
- X86::VMOVSDrm, X86::VMOVSSrm, 1);
+ return FixupConstant({{X86::VMOVSSrm, 1, 32, rebuildZeroUpperCst},
+ {X86::VBROADCASTSSrm, 1, 32, rebuildSplatCst},
+ {X86::VMOVSDrm, 1, 64, rebuildZeroUpperCst},
+ {X86::VMOVDDUPrm, 1, 64, rebuildSplatCst}},
+ 1);
case X86::VMOVAPDYrm:
case X86::VMOVAPSYrm:
case X86::VMOVUPDYrm:
case X86::VMOVUPSYrm:
- return FixupConstant(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
- X86::VBROADCASTSSYrm, 0, 0, 0, 0, 1);
+ return FixupConstant({{X86::VBROADCASTSSYrm, 1, 32, rebuildSplatCst},
+ {X86::VBROADCASTSDYrm, 1, 64, rebuildSplatCst},
+ {X86::VBROADCASTF128rm, 1, 128, rebuildSplatCst}},
+ 1);
case X86::VMOVAPDZ128rm:
case X86::VMOVAPSZ128rm:
case X86::VMOVUPDZ128rm:
case X86::VMOVUPSZ128rm:
- return FixupConstant(0, 0, X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0,
- 0, X86::VMOVSDZrm, X86::VMOVSSZrm, 1);
+ return FixupConstant({{X86::VMOVSSZrm, 1, 32, rebuildZeroUpperCst},
+ {X86::VBROADCASTSSZ128rm, 1, 32, rebuildSplatCst},
+ {X86::VMOVSDZrm, 1, 64, rebuildZeroUpperCst},
+ {X86::VMOVDDUPZ128rm, 1, 64, rebuildSplatCst}},
+ 1);
case X86::VMOVAPDZ256rm:
case X86::VMOVAPSZ256rm:
case X86::VMOVUPDZ256rm:
case X86::VMOVUPSZ256rm:
- return FixupConstant(0, X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
- X86::VBROADCASTSSZ256rm, 0, 0, 0, 0, 1);
+ return FixupConstant(
+ {{X86::VBROADCASTSSZ256rm, 1, 32, rebuildSplatCst},
+ {X86::VBROADCASTSDZ256rm, 1, 64, rebuildSplatCst},
+ {X86::VBROADCASTF32X4Z256rm, 1, 128, rebuildSplatCst}},
+ 1);
case X86::VMOVAPDZrm:
case X86::VMOVAPSZrm:
case X86::VMOVUPDZrm:
case X86::VMOVUPSZrm:
- return FixupConstant(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
- X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0, 0, 0,
+ return FixupConstant({{X86::VBROADCASTSSZrm, 1, 32, rebuildSplatCst},
+ {X86::VBROADCASTSDZrm, 1, 64, rebuildSplatCst},
+ {X86::VBROADCASTF32X4rm, 1, 128, rebuildSplatCst},
+ {X86::VBROADCASTF64X4rm, 1, 256, rebuildSplatCst}},
1);
/* Integer Loads */
case X86::MOVDQArm:
- case X86::MOVDQUrm:
- return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
+ case X86::MOVDQUrm: {
+ return FixupConstant({{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
+ {X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst}},
1);
+ }
case X86::VMOVDQArm:
- case X86::VMOVDQUrm:
- return FixupConstant(0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
- HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
- HasAVX2 ? X86::VPBROADCASTWrm : 0,
- HasAVX2 ? X86::VPBROADCASTBrm : 0, X86::VMOVQI2PQIrm,
- X86::VMOVDI2PDIrm, 1);
+ case X86::VMOVDQUrm: {
+ FixupEntry Fixups[] = {
+ {HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
+ {HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
+ {X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
+ {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
+ rebuildSplatCst},
+ {X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
+ {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
+ rebuildSplatCst},
+ };
+ return FixupConstant(Fixups, 1);
+ }
case X86::VMOVDQAYrm:
- case X86::VMOVDQUYrm:
- return FixupConstant(
- 0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
- HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
- HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
- HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0,
- 0, 0, 1);
+ case X86::VMOVDQUYrm: {
+ FixupEntry Fixups[] = {
+ {HasAVX2 ? X86::VPBROADCASTBYrm : 0, 1, 8, rebuildSplatCst},
+ {HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
+ {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
+ rebuildSplatCst},
+ {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
+ rebuildSplatCst},
+ {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
+ rebuildSplatCst}};
+ return FixupConstant(Fixups, 1);
+ }
case X86::VMOVDQA32Z128rm:
case X86::VMOVDQA64Z128rm:
case X86::VMOVDQU32Z128rm:
- case X86::VMOVDQU64Z128rm:
- return FixupConstant(0, 0, X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm,
- HasBWI ? X86::VPBROADCASTWZ128rm : 0,
- HasBWI ? X86::VPBROADCASTBZ128rm : 0,
- X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm, 1);
+ case X86::VMOVDQU64Z128rm: {
+ FixupEntry Fixups[] = {
+ {HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
+ {HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
+ {X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
+ {X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
+ {X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
+ {X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst}};
+ return FixupConstant(Fixups, 1);
+ }
case X86::VMOVDQA32Z256rm:
case X86::VMOVDQA64Z256rm:
case X86::VMOVDQU32Z256rm:
- case X86::VMOVDQU64Z256rm:
- return FixupConstant(0, X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm,
- X86::VPBROADCASTDZ256rm,
- HasBWI ? X86::VPBROADCASTWZ256rm : 0,
- HasBWI ? X86::VPBROADCASTBZ256rm : 0, 0, 0, 1);
+ case X86::VMOVDQU64Z256rm: {
+ FixupEntry Fixups[] = {
+ {HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
+ {HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
+ {X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
+ {X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
+ {X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst}};
+ return FixupConstant(Fixups, 1);
+ }
case X86::VMOVDQA32Zrm:
case X86::VMOVDQA64Zrm:
case X86::VMOVDQU32Zrm:
- case X86::VMOVDQU64Zrm:
- return FixupConstant(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
- X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
- HasBWI ? X86::VPBROADCASTWZrm : 0,
- HasBWI ? X86::VPBROADCASTBZrm : 0, 0, 0, 1);
+ case X86::VMOVDQU64Zrm: {
+ FixupEntry Fixups[] = {
+ {HasBWI ? X86::VPBROADCASTBZrm : 0, 1, 8, rebuildSplatCst},
+ {HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
+ {X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
+ {X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
+ {X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
+ {X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst}};
+ return FixupConstant(Fixups, 1);
+ }
}
auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
@@ -423,7 +463,9 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
if (OpBcst32 || OpBcst64) {
unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
- return FixupConstant(0, 0, OpBcst64, OpBcst32, 0, 0, 0, 0, OpNo);
+ FixupEntry Fixups[] = {{(int)OpBcst32, 32, 32, rebuildSplatCst},
+ {(int)OpBcst64, 64, 64, rebuildSplatCst}};
+ return FixupConstant(Fixups, OpNo);
}
return false;
};
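A note on the ordering invariant visible in the diff: the sort key asserted under EXPENSIVE_CHECKS is NumCstElts * BitWidth, the total size of the rebuilt constant pool entry, so a {1, 32} entry (32 bits) must precede a {1, 64} entry (64 bits), and in the final hunk {OpBcst32, 32, 32} (1024 bits) precedes {OpBcst64, 64, 64} (4096 bits). Where two entries tie at the same size, the tables keep the vzload form ahead of the broadcast form (e.g. VMOVSSrm before VBROADCASTSSrm), preserving the removed comment's "Prefer vzload vs broadcast for same bitwidth to avoid domain flips".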