[llvm] [X86] combineGatherScatter - split non-constant (add v, (splat b)) indices patterns and add the splat into the (scalar) base address (PR #135201)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 10 08:52:20 PDT 2025


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/135201

We already did this for constant cases; this patch generalizes the existing fold to attempt to extract the splat from either operand of an ADD node for the gather/scatter index value.

This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 (64-bit pointer) targets in the future.

Noticed while reviewing #134979

CC @rohitaggarwal007

>From b74a46cdb32e92c57bda5b9c041d8a73513293b3 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Thu, 10 Apr 2025 16:48:47 +0100
Subject: [PATCH] [X86] combineGatherScatter - split non-constant (add v,
 (splat b)) indices patterns and add the splat into the (scalar) base address

We already did this for constant cases; this patch generalizes the existing fold to attempt to extract the splat from either operand of an ADD node for the gather/scatter index value.

This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 (64-bit pointer) targets in the future.

Noticed while reviewing #134979
---
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 69 +++++++++++--------
 .../test/CodeGen/X86/masked_gather_scatter.ll |  8 +--
 2 files changed, 42 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a3c423270f44a..77808608045f9 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56521,6 +56521,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   SDValue Base = GorS->getBasePtr();
   SDValue Scale = GorS->getScale();
   EVT IndexVT = Index.getValueType();
+  EVT IndexSVT = IndexVT.getVectorElementType();
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   if (DCI.isBeforeLegalize()) {
@@ -56557,41 +56558,51 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   }
 
   EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
-  // Try to move splat constant adders from the index operand to the base
+
+  // Try to move splat adders from the index operand to the base
   // pointer operand. Taking care to multiply by the scale. We can only do
   // this when index element type is the same as the pointer type.
   // Otherwise we need to be sure the math doesn't wrap before the scale.
-  if (Index.getOpcode() == ISD::ADD &&
-      IndexVT.getVectorElementType() == PtrVT && isa<ConstantSDNode>(Scale)) {
+  if (Index.getOpcode() == ISD::ADD && IndexSVT == PtrVT &&
+      isa<ConstantSDNode>(Scale)) {
     uint64_t ScaleAmt = Scale->getAsZExtVal();
-    if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
-      BitVector UndefElts;
-      if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
-        // FIXME: Allow non-constant?
-        if (UndefElts.none()) {
-          // Apply the scale.
-          APInt Adder = C->getAPIntValue() * ScaleAmt;
-          // Add it to the existing base.
-          Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
-                             DAG.getConstant(Adder, DL, PtrVT));
-          Index = Index.getOperand(0);
-          return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
-        }
-      }
 
-      // It's also possible base is just a constant. In that case, just
-      // replace it with 0 and move the displacement into the index.
-      if (BV->isConstant() && isa<ConstantSDNode>(Base) &&
-          isOneConstant(Scale)) {
-        SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
-        // Combine the constant build_vector and the constant base.
-        Splat = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(1), Splat);
-        // Add to the LHS of the original Index add.
-        Index = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(0), Splat);
-        Base = DAG.getConstant(0, DL, Base.getValueType());
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+    for (unsigned I = 0; I != 2; ++I)
+      if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(I))) {
+        BitVector UndefElts;
+        if (SDValue Splat = BV->getSplatValue(&UndefElts)) {
+          if (UndefElts.none()) {
+            // If the splat value is constant we can add the scaled splat value
+            // to the existing base.
+            if (auto *C = dyn_cast<ConstantSDNode>(Splat)) {
+              APInt Adder = C->getAPIntValue() * ScaleAmt;
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
+                                            DAG.getConstant(Adder, DL, PtrVT));
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+            // For non-constant cases, limit this to non-scaled cases.
+            if (ScaleAmt == 1) {
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+          }
+        }
+        // It's also possible base is just a constant. In that case, just
+        // replace it with 0 and move the displacement into the index.
+        if (ScaleAmt == 1 && BV->isConstant() && isa<ConstantSDNode>(Base)) {
+          SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
+          // Combine the constant build_vector and the constant base.
+          Splat =
+              DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(I), Splat);
+          // Add to the other half of the original Index add.
+          SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
+                                         Index.getOperand(1 - I), Splat);
+          SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
+          return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+        }
       }
-    }
   }
 
   if (DCI.isBeforeLegalizeOps()) {
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 5effb18fb6aa6..46e589b7b1be9 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -5028,12 +5028,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; X86-KNL-NEXT:    vpslld $4, (%ecx), %zmm2
-; X86-KNL-NEXT:    vpbroadcastd %eax, %zmm0
-; X86-KNL-NEXT:    vpaddd %zmm2, %zmm0, %zmm3
 ; X86-KNL-NEXT:    kmovw %k1, %k2
 ; X86-KNL-NEXT:    vmovaps %zmm1, %zmm0
 ; X86-KNL-NEXT:    vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-KNL-NEXT:    vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-KNL-NEXT:    vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-KNL-NEXT:    retl
 ;
 ; X64-SKX-SMALL-LABEL: test_gather_16f32_mask_index_pair:
@@ -5097,12 +5095,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; X86-SKX-NEXT:    vpslld $4, (%ecx), %zmm2
-; X86-SKX-NEXT:    vpbroadcastd %eax, %zmm0
-; X86-SKX-NEXT:    vpaddd %zmm2, %zmm0, %zmm3
 ; X86-SKX-NEXT:    kmovw %k1, %k2
 ; X86-SKX-NEXT:    vmovaps %zmm1, %zmm0
 ; X86-SKX-NEXT:    vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-SKX-NEXT:    vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-SKX-NEXT:    vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-SKX-NEXT:    retl
   %wide.load = load <16 x i32>, ptr %arr, align 4
   %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>



More information about the llvm-commits mailing list