[llvm] 937cfdc - [X86] combineGatherScatter - split non-constant (add v, (splat b)) indices patterns and add the splat into the (scalar) base address (#135201)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 11 02:53:56 PDT 2025
Author: Simon Pilgrim
Date: 2025-04-11T10:53:53+01:00
New Revision: 937cfdc7be56cb5cb46c098eecdb8e4f524add0d
URL: https://github.com/llvm/llvm-project/commit/937cfdc7be56cb5cb46c098eecdb8e4f524add0d
DIFF: https://github.com/llvm/llvm-project/commit/937cfdc7be56cb5cb46c098eecdb8e4f524add0d.diff
LOG: [X86] combineGatherScatter - split non-constant (add v, (splat b)) indices patterns and add the splat into the (scalar) base address (#135201)
We already did this for constant splat cases; this patch generalizes the existing fold to attempt to extract a splat from either operand of an ADD node used as the gather/scatter index value.
This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 targets with 64-bit pointers in the future.
Noticed while reviewing #134979
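To illustrate the fold, here is a minimal IR sketch (the function and value names are illustrative and not taken from the patch's tests), assuming a 32-bit x86 triple so that the i32 index element type matches the pointer width. The vector index is (add %idx, splat(%b)); after this combine the splat of %b is expected to be folded into the scalar base address, leaving %idx alone as the vector index. A byte-addressed GEP keeps the gather scale at 1, which the non-constant splat case requires.

define <16 x float> @gather_index_add_splat(ptr %base, <16 x i32> %idx, i32 %b, <16 x i1> %mask) {
  ; Materialize splat(%b) and add it to the vector index.
  %b.ins   = insertelement <16 x i32> poison, i32 %b, i64 0
  %b.splat = shufflevector <16 x i32> %b.ins, <16 x i32> poison, <16 x i32> zeroinitializer
  %sum     = add <16 x i32> %idx, %b.splat
  ; Byte-addressed vector GEP feeding a masked gather (scale 1 at the DAG level).
  %ptrs = getelementptr i8, ptr %base, <16 x i32> %sum
  %res  = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> zeroinitializer)
  ret <16 x float> %res
}

declare <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr>, i32 immarg, <16 x i1>, <16 x float>)

For a constant splat the fold also handles larger scales, since the constant is pre-multiplied by the scale amount before being added to the base.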
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/masked_gather_scatter.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6e6431008e680..4e94d2cef6bd8 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56517,6 +56517,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SDValue Base = GorS->getBasePtr();
SDValue Scale = GorS->getScale();
EVT IndexVT = Index.getValueType();
+ EVT IndexSVT = IndexVT.getVectorElementType();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (DCI.isBeforeLegalize()) {
@@ -56553,41 +56554,51 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
}
EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
- // Try to move splat constant adders from the index operand to the base
+
+ // Try to move splat adders from the index operand to the base
// pointer operand. Taking care to multiply by the scale. We can only do
// this when index element type is the same as the pointer type.
// Otherwise we need to be sure the math doesn't wrap before the scale.
- if (Index.getOpcode() == ISD::ADD &&
- IndexVT.getVectorElementType() == PtrVT && isa<ConstantSDNode>(Scale)) {
+ if (Index.getOpcode() == ISD::ADD && IndexSVT == PtrVT &&
+ isa<ConstantSDNode>(Scale)) {
uint64_t ScaleAmt = Scale->getAsZExtVal();
- if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
- BitVector UndefElts;
- if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
- // FIXME: Allow non-constant?
- if (UndefElts.none()) {
- // Apply the scale.
- APInt Adder = C->getAPIntValue() * ScaleAmt;
- // Add it to the existing base.
- Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
- DAG.getConstant(Adder, DL, PtrVT));
- Index = Index.getOperand(0);
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
- }
- }
- // It's also possible base is just a constant. In that case, just
- // replace it with 0 and move the displacement into the index.
- if (BV->isConstant() && isa<ConstantSDNode>(Base) &&
- isOneConstant(Scale)) {
- SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
- // Combine the constant build_vector and the constant base.
- Splat = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(1), Splat);
- // Add to the LHS of the original Index add.
- Index = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(0), Splat);
- Base = DAG.getConstant(0, DL, Base.getValueType());
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ for (unsigned I = 0; I != 2; ++I)
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(I))) {
+ BitVector UndefElts;
+ if (SDValue Splat = BV->getSplatValue(&UndefElts)) {
+ if (UndefElts.none()) {
+ // If the splat value is constant we can add the scaled splat value
+ // to the existing base.
+ if (auto *C = dyn_cast<ConstantSDNode>(Splat)) {
+ APInt Adder = C->getAPIntValue() * ScaleAmt;
+ SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
+ DAG.getConstant(Adder, DL, PtrVT));
+ SDValue NewIndex = Index.getOperand(1 - I);
+ return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ }
+ // For non-constant cases, limit this to non-scaled cases.
+ if (ScaleAmt == 1) {
+ SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
+ SDValue NewIndex = Index.getOperand(1 - I);
+ return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ }
+ }
+ }
+ // It's also possible base is just a constant. In that case, just
+ // replace it with 0 and move the displacement into the index.
+ if (ScaleAmt == 1 && BV->isConstant() && isa<ConstantSDNode>(Base)) {
+ SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
+ // Combine the constant build_vector and the constant base.
+ Splat =
+ DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(I), Splat);
+ // Add to the other half of the original Index add.
+ SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
+ Index.getOperand(1 - I), Splat);
+ SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
+ return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ }
}
- }
}
if (DCI.isBeforeLegalizeOps()) {
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 5effb18fb6aa6..46e589b7b1be9 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -5028,12 +5028,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %ecx
; X86-KNL-NEXT: vpslld $4, (%ecx), %zmm2
-; X86-KNL-NEXT: vpbroadcastd %eax, %zmm0
-; X86-KNL-NEXT: vpaddd %zmm2, %zmm0, %zmm3
; X86-KNL-NEXT: kmovw %k1, %k2
; X86-KNL-NEXT: vmovaps %zmm1, %zmm0
; X86-KNL-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-KNL-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-KNL-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
; X86-KNL-NEXT: retl
;
; X64-SKX-SMALL-LABEL: test_gather_16f32_mask_index_pair:
@@ -5097,12 +5095,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %ecx
; X86-SKX-NEXT: vpslld $4, (%ecx), %zmm2
-; X86-SKX-NEXT: vpbroadcastd %eax, %zmm0
-; X86-SKX-NEXT: vpaddd %zmm2, %zmm0, %zmm3
; X86-SKX-NEXT: kmovw %k1, %k2
; X86-SKX-NEXT: vmovaps %zmm1, %zmm0
; X86-SKX-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-SKX-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-SKX-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
; X86-SKX-NEXT: retl
%wide.load = load <16 x i32>, ptr %arr, align 4
%and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>