[llvm] e8900df - [X86] Move gather/scatter index shl(x, c) -> index:x, scale:c fold into X86DAGToDAGISel::matchIndexRecursively
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 22 07:29:59 PDT 2023
Author: Simon Pilgrim
Date: 2023-08-22T15:29:44+01:00
New Revision: e8900df8a769c40661daeb226dfbcd7ec9939a85
URL: https://github.com/llvm/llvm-project/commit/e8900df8a769c40661daeb226dfbcd7ec9939a85
DIFF: https://github.com/llvm/llvm-project/commit/e8900df8a769c40661daeb226dfbcd7ec9939a85.diff
LOG: [X86] Move gather/scatter index shl(x,c) -> index:x, scale:c fold into X86DAGToDAGISel::matchIndexRecursively
Added:
Modified:
llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index d5b7fe3aa6cb02..1784a8103a66bb 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2216,6 +2216,7 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
if (Depth >= SelectionDAG::MaxRecursionDepth)
return N;
+ // index: add(x,c) -> index: x, disp + c
if (CurDAG->isBaseWithConstantOffset(N)) {
auto *AddVal = cast<ConstantSDNode>(N.getOperand(1));
uint64_t Offset = (uint64_t)AddVal->getSExtValue() * AM.Scale;
@@ -2223,6 +2224,24 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
}
+ // index: add(x,x) -> index: x, scale * 2
+ if (N.getOpcode() == ISD::ADD && N.getOperand(0) == N.getOperand(1)) {
+ if (AM.Scale <= 4) {
+ AM.Scale *= 2;
+ return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+ }
+ }
+
+ // index: shl(x,i) -> index: x, scale * (1 << i)
+ if (N.getOpcode() == X86ISD::VSHLI) {
+ uint64_t ShiftAmt = N.getConstantOperandVal(1);
+ uint64_t ScaleAmt = 1ULL << ShiftAmt;
+ if ((AM.Scale * ScaleAmt) <= 8) {
+ AM.Scale *= ScaleAmt;
+ return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+ }
+ }
+
// TODO: Handle extensions, shifted masks etc.
return N;
}
@@ -2672,9 +2691,15 @@ bool X86DAGToDAGISel::selectVectorAddr(MemSDNode *Parent, SDValue BasePtr,
SDValue &Index, SDValue &Disp,
SDValue &Segment) {
X86ISelAddressMode AM;
- AM.IndexReg = IndexOp;
AM.Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue();
+ // Attempt to match index patterns, as long as we're not relying on implicit
+ // sign-extension, which is performed BEFORE scale.
+ if (IndexOp.getScalarValueSizeInBits() == BasePtr.getScalarValueSizeInBits())
+ AM.IndexReg = matchIndexRecursively(IndexOp, AM, 0);
+ else
+ AM.IndexReg = IndexOp;
+
unsigned AddrSpace = Parent->getPointerInfo().getAddrSpace();
if (AddrSpace == X86AS::GS)
AM.Segment = CurDAG->getRegister(X86::GS, MVT::i16);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0d2f3b00313a74..9f70d6cedb7611 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -52947,43 +52947,10 @@ static SDValue combineTESTP(SDNode *N, SelectionDAG &DAG,
}
static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
- const X86Subtarget &Subtarget) {
+ TargetLowering::DAGCombinerInfo &DCI) {
auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
- SDValue BasePtr = MemOp->getBasePtr();
- SDValue Index = MemOp->getIndex();
- SDValue Scale = MemOp->getScale();
SDValue Mask = MemOp->getMask();
- // Attempt to fold an index scale into the scale value directly.
- // For smaller indices, implicit sext is performed BEFORE scale, preventing
- // this fold under most circumstances.
- // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively?
- if ((Index.getOpcode() == X86ISD::VSHLI ||
- (Index.getOpcode() == ISD::ADD &&
- Index.getOperand(0) == Index.getOperand(1))) &&
- isa<ConstantSDNode>(Scale) &&
- BasePtr.getScalarValueSizeInBits() == Index.getScalarValueSizeInBits()) {
- unsigned ShiftAmt =
- Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1);
- uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
- uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
- if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
- SDValue NewIndex = Index.getOperand(0);
- SDValue NewScale =
- DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType());
- if (N->getOpcode() == X86ISD::MGATHER)
- return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG,
- MemOp->getOperand(1), Mask,
- MemOp->getBasePtr(), NewIndex, NewScale,
- MemOp->getChain(), Subtarget);
- if (N->getOpcode() == X86ISD::MSCATTER)
- return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG,
- MemOp->getOperand(1), Mask, MemOp->getBasePtr(),
- NewIndex, NewScale, MemOp->getChain(), Subtarget);
- }
- }
-
// With vector masks we only demand the upper bit of the mask.
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -55920,8 +55887,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI, Subtarget);
case X86ISD::TESTP: return combineTESTP(N, DAG, DCI, Subtarget);
case X86ISD::MGATHER:
- case X86ISD::MSCATTER:
- return combineX86GatherScatter(N, DAG, DCI, Subtarget);
+ case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI);
case ISD::MGATHER:
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
case X86ISD::PCMPEQ:
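
A note on the transform for readers following along: the x86 SIB encoding only supports scales of 1, 2, 4 and 8, which is why both folds above cap the combined scale at 8. The fold is also only attempted when the index is already pointer-width (the new width check in selectVectorAddr), because gather/scatter sign-extends a narrower index BEFORE applying the scale. For example, with an i32 index x = 0x40000000, sext(shl(x,1)) gives 0xFFFFFFFF80000000, while sext(x) * 2 gives 0x0000000080000000, so folding the shift into the scale would compute a different address.

Below is a minimal stand-alone sketch of the recursion's shape under those constraints. All names (Node, matchIndex, MaxDepth) are hypothetical; this is not the LLVM code, just the scale-composition logic it implements:

#include <cstdint>

enum class Op { Leaf, Add, Shl };

// Hypothetical expression node standing in for an SDValue.
struct Node {
  Op Opc = Op::Leaf;
  const Node *L = nullptr, *R = nullptr;
  uint64_t ShiftAmt = 0; // constant shift amount, only used for Shl
};

// Strips scale-foldable operations off N, multiplying Scale as it goes.
// Mirrors the structure of X86DAGToDAGISel::matchIndexRecursively.
const Node *matchIndex(const Node *N, uint64_t &Scale, unsigned Depth) {
  constexpr unsigned MaxDepth = 6; // stand-in for the real recursion limit
  if (Depth >= MaxDepth)
    return N;

  // index: add(x,x) -> index: x, scale * 2
  if (N->Opc == Op::Add && N->L == N->R && Scale <= 4) {
    Scale *= 2;
    return matchIndex(N->L, Scale, Depth + 1);
  }

  // index: shl(x,c) -> index: x, scale * (1 << c)
  // Bounding c at 3 keeps the factor at most 8 and avoids shift overflow.
  if (N->Opc == Op::Shl && N->ShiftAmt <= 3) {
    uint64_t Factor = 1ULL << N->ShiftAmt;
    if (Scale * Factor <= 8) {
      Scale *= Factor;
      return matchIndex(N->L, Scale, Depth + 1);
    }
  }
  return N;
}

Worked through the sketch: starting from Scale = 1, an index of shl(add(x,x), 1) first folds the shift (Scale = 2), then the self-add (Scale = 4), leaving index x; any further fold that would push the combined scale past 8 is rejected and the remaining expression stays in the index register.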