[llvm] 2665b2a - [X86] Pull out combineConstantPoolLoads helper from combineLoad. NFC.
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed May 29 10:06:39 PDT 2024
Author: Simon Pilgrim
Date: 2024-05-29T18:05:41+01:00
New Revision: 2665b2a6ddb1625799536c45ca15605a6f24c081
URL: https://github.com/llvm/llvm-project/commit/2665b2a6ddb1625799536c45ca15605a6f24c081
DIFF: https://github.com/llvm/llvm-project/commit/2665b2a6ddb1625799536c45ca15605a6f24c081.diff
LOG: [X86] Pull out combineConstantPoolLoads helper from combineLoad. NFC.
The logic is already pretty dense and a future patch will further complicate this.
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2d8343ffa1a0b..24340e135b08b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -50823,10 +50823,83 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
   return SDValue();
 }

+static SDValue combineConstantPoolLoads(SDNode *N, const SDLoc &dl,
+                                        SelectionDAG &DAG,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        const X86Subtarget &Subtarget) {
+  auto *Ld = cast<LoadSDNode>(N);
+  EVT RegVT = Ld->getValueType(0);
+  EVT MemVT = Ld->getMemoryVT();
+  SDValue Ptr = Ld->getBasePtr();
+  SDValue Chain = Ld->getChain();
+  ISD::LoadExtType Ext = Ld->getExtensionType();
+
+  if (Ext != ISD::NON_EXTLOAD || !Subtarget.hasAVX() || !Ld->isSimple())
+    return SDValue();
+
+  if (!(RegVT.is128BitVector() || RegVT.is256BitVector()))
+    return SDValue();
+
+  auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
+                         ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
+    for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
+      if (Undefs[I])
+        continue;
+      if (UserUndefs[I] || Bits[I] != UserBits[I])
+        return false;
+    }
+    return true;
+  };
+
+  // Look through all other loads/broadcasts in the chain for another constant
+  // pool entry.
+  for (SDNode *User : Chain->uses()) {
+    auto *UserLd = dyn_cast<MemSDNode>(User);
+    if (User != N && UserLd &&
+        (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
+         User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
+         ISD::isNormalLoad(User)) &&
+        UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
+        User->getValueSizeInBits(0).getFixedValue() >
+            RegVT.getFixedSizeInBits()) {
+      EVT UserVT = User->getValueType(0);
+      SDValue UserPtr = UserLd->getBasePtr();
+      const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
+      const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
+
+      // See if we are loading a constant that matches in the lower
+      // bits of a longer constant (but from a different constant pool ptr).
+      if (LdC && UserC && UserPtr != Ptr) {
+        unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
+        unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
+        if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
+          APInt Undefs, UserUndefs;
+          SmallVector<APInt> Bits, UserBits;
+          unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
+                                      UserVT.getScalarSizeInBits());
+          if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
+                                            Bits) &&
+              getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
+                                            UserUndefs, UserBits)) {
+            if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
+              SDValue Extract = extractSubVector(
+                  SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
+              Extract = DAG.getBitcast(RegVT, Extract);
+              return DCI.CombineTo(N, Extract, SDValue(User, 1));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
                            TargetLowering::DAGCombinerInfo &DCI,
                            const X86Subtarget &Subtarget) {
-  LoadSDNode *Ld = cast<LoadSDNode>(N);
+  auto *Ld = cast<LoadSDNode>(N);
   EVT RegVT = Ld->getValueType(0);
   EVT MemVT = Ld->getMemoryVT();
   SDLoc dl(Ld);
@@ -50885,7 +50958,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
     }
   }

-  // If we also load/broadcast this to a wider type, then just extract the
+  // If we also broadcast this vector to a wider type, then just extract the
   // lowest subvector.
   if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
       (RegVT.is128BitVector() || RegVT.is256BitVector())) {
@@ -50894,61 +50967,23 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
     for (SDNode *User : Chain->uses()) {
       auto *UserLd = dyn_cast<MemSDNode>(User);
       if (User != N && UserLd &&
-          (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
-           User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
-           ISD::isNormalLoad(User)) &&
-          UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
+          User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
+          UserLd->getChain() == Chain && UserLd->getBasePtr() == Ptr &&
+          UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits() &&
+          !User->hasAnyUseOfValue(1) &&
           User->getValueSizeInBits(0).getFixedValue() >
               RegVT.getFixedSizeInBits()) {
-        if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
-            UserLd->getBasePtr() == Ptr &&
-            UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
-          SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
-                                             RegVT.getSizeInBits());
-          Extract = DAG.getBitcast(RegVT, Extract);
-          return DCI.CombineTo(N, Extract, SDValue(User, 1));
-        }
-        auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
-                               ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
-          for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
-            if (Undefs[I])
-              continue;
-            if (UserUndefs[I] || Bits[I] != UserBits[I])
-              return false;
-          }
-          return true;
-        };
-        // See if we are loading a constant that matches in the lower
-        // bits of a longer constant (but from a different constant pool ptr).
-        EVT UserVT = User->getValueType(0);
-        SDValue UserPtr = UserLd->getBasePtr();
-        const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
-        const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
-        if (LdC && UserC && UserPtr != Ptr) {
-          unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
-          unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
-          if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
-            APInt Undefs, UserUndefs;
-            SmallVector<APInt> Bits, UserBits;
-            unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
-                                        UserVT.getScalarSizeInBits());
-            if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
-                                              Bits) &&
-                getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
-                                              UserUndefs, UserBits)) {
-              if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
-                SDValue Extract = extractSubVector(
-                    SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
-                Extract = DAG.getBitcast(RegVT, Extract);
-                return DCI.CombineTo(N, Extract, SDValue(User, 1));
-              }
-            }
-          }
-        }
+        SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, dl,
+                                           RegVT.getSizeInBits());
+        Extract = DAG.getBitcast(RegVT, Extract);
+        return DCI.CombineTo(N, Extract, SDValue(User, 1));
       }
     }
   }

+  if (SDValue V = combineConstantPoolLoads(Ld, dl, DAG, DCI, Subtarget))
+    return V;
+
   // Cast ptr32 and ptr64 pointers to the default address space before a load.
   unsigned AddrSpace = Ld->getAddressSpace();
   if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||