[llvm] Fix the Gather's base and index by modifying the Scale value (PR #134979)

Rohit Aggarwal via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 9 02:04:22 PDT 2025


https://github.com/rohitaggarwal007 created https://github.com/llvm/llvm-project/pull/134979

Fix the Gather's base and index for both single-use and multiple-use Index nodes. The approach is to update the Scale when the Index is an SHL, which is then followed by a truncate of the index.


>From 741acb05b09ff5333c0166f191e4ad2ffab88496 Mon Sep 17 00:00:00 2001
From: Rohit Aggarwal <Rohit.Aggarwal at amd.com>
Date: Wed, 19 Mar 2025 15:04:43 +0530
Subject: [PATCH] Fix the Gather's base and index for one use or multiple uses
 of Index Node. Using the approach to update the Scale if SHL Opcode and 
 followed by truncate.

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |   4 +-
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 143 ++++++++++++++++++
 llvm/test/CodeGen/X86/gatherBaseIndexFix.ll   |  68 +++++++++
 3 files changed, 213 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/gatherBaseIndexFix.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 38376de5783ae..7c51ee8222512 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12131,8 +12131,8 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
   if (IndexIsScaled)
     return false;
 
-  if (!isNullConstant(BasePtr) && !Index.hasOneUse())
-    return false;
+  //  if (!isNullConstant(BasePtr) && !Index.hasOneUse())
+  //    return false;
 
   EVT VT = BasePtr.getValueType();
 
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 47ac1ee571269..61e6d0734f402 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56512,6 +56512,120 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
                               Scatter->isTruncatingStore());
 }
 
+// Attempts to rewrite Base, Index and Scale for a non-uniform GEP address by
+// moving a constant left shift of an extended index into Scale.
+static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, SDValue &Scale,
+                               const SDLoc &DL, const SDValue &Gep,
+                               SelectionDAG &DAG) {
+  SDValue Nbase;
+  SDValue Nindex;
+  SDValue NScale;
+  bool Changed = false;
+  // Matches Idx == (shl (ext X), C); on success it sets Nindex and NScale.
+  auto checkAndUpdateIndex = [&](SDValue &Idx) {
+    if (Idx.getOpcode() == ISD::SHL) {  // shl (ext X), ShAmt
+      SDValue Op10 = Idx.getOperand(0); // Expected to be a zext/sext (checked below)
+      SDValue Op11 = Idx.getOperand(1); // Shift amount; only used by disabled code
+      std::optional<uint64_t> ShAmt = DAG.getValidMinimumShiftAmount(Idx);
+
+      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+          IndexWidth > 32 &&
+          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && ShAmt) {
+
+        KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+        bool ExtIsNonNegative = ExtKnown.isNonNegative();
+        KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+        bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+        if (!ExtIsNonNegative || !ExtOpIsNonNegative)
+          return false;
+
+        SDValue NewOp10 =
+            Op10.getOperand(0);          // Source operand of the extend
+        EVT VT = NewOp10.getValueType(); // Type of the un-extended operand; only
+                                         // used by the disabled code below
+
+        // auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+        // if (!ConstEltNo)
+        //   return false;
+        uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
+        uint64_t NewScaleAmt = ScaleAmt * (1ULL << *ShAmt);
+        LLVM_DEBUG(dbgs() << NewScaleAmt << " NewScaleAmt"
+                          << "\n");
+        if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
+          // Nindex = NewOp10.getOperand(0);
+          Nindex = Op10;
+          NScale = DAG.getTargetConstant(NewScaleAmt, DL, Scale.getValueType());
+          return true;
+        }
+        // SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+        //                            DAG.getConstant(ConstEltNo->getZExtValue(),
+        //                                            DL, VT.getScalarType()));
+        // Nindex = DAG.getNode(ISD::SHL, DL, VT, NewOp10,
+        //                     DAG.getBuildVector(VT, DL, Ops));
+      }
+    }
+    return false;
+  };
+
+  // Base is null and the whole address (Gep) is an ADD: walk the ADD's operands
+  // to recover a base and an index candidate for checkAndUpdateIndex.
+  if (isNullConstant(Base) && Gep.getOpcode() == ISD::ADD) {
+    SDValue Op0 = Gep.getOperand(0); // base or add
+    SDValue Op1 = Gep.getOperand(1); // build vector or SHL
+    Nbase = Op0;
+    SDValue Idx = Op1;
+    auto Flags = Gep->getFlags();
+
+    if (Op0->getOpcode() == ISD::ADD) { // add Base, Idx
+      SDValue Op00 = Op0.getOperand(0); // Base
+      Nbase = Op00;
+      Idx = Op0.getOperand(1);
+    } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+                 Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+      return false;
+    }
+    if (!checkAndUpdateIndex(Idx)) {
+      return false;
+    }
+    Base = Nbase.getOperand(0); // NOTE(review): not undone if we bail out below
+
+    if (Op0 != Nbase) {
+      auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+      if (!ConstEltNo)
+        return false;
+
+      // SmallVector<SDValue, 8> Ops(
+      //    Nindex.getValueType().getVectorNumElements(),
+      //    DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+      //                    Nindex.getValueType().getScalarType()));
+      Base = DAG.getNode(ISD::ADD, DL, Nbase.getOperand(0).getValueType(),
+                         Nbase.getOperand(0), Op1.getOperand(0), Flags);
+    }
+    Index = Nindex;
+    Scale = NScale;
+    Changed = true;
+  } else if (Base.getOpcode() == ISD::CopyFromReg ||
+             (Base.getOpcode() == ISD::ADD &&
+              Base.getOperand(0).getOpcode() == ISD::CopyFromReg &&
+              isConstOrConstSplat(Base.getOperand(1)))) {
+    if (checkAndUpdateIndex(Index)) {
+      Index = Nindex; // NOTE(review): Scale is not set to NScale here; confirm
+      Changed = true;
+    }
+  }
+  if (Changed) {
+    LLVM_DEBUG(dbgs() << "Successful in updating the non uniform gep "
+                         "information\n";
+               dbgs() << "updated base "; Base.dump();
+               dbgs() << "updated Index "; Index.dump(););
+    return true;
+  }
+  return false;
+}
+
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
@@ -56523,6 +56637,29 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   if (DCI.isBeforeLegalize()) {
+    //    if (updateBaseAndIndex(Base, Index, Scale, DL, Index, DAG))
+    //      return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+    //
+
+    // Attempt to move shifted index into the address scale, allows further
+    // index truncation below.
+    if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale)) {
+      uint64_t ScaleAmt = Scale->getAsZExtVal();
+      if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
+        if (*MinShAmt >= 1 && ScaleAmt < 8 &&
+            DAG.ComputeNumSignBits(Index.getOperand(0)) > 1) {
+          SDValue ShAmt = Index.getOperand(1);
+          SDValue NewShAmt =
+              DAG.getNode(ISD::SUB, DL, ShAmt.getValueType(), ShAmt,
+                          DAG.getConstant(1, DL, ShAmt.getValueType()));
+          SDValue NewIndex = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+                                         Index.getOperand(0), NewShAmt);
+          SDValue NewScale =
+              DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
+          return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+        }
+      }
+    }
     unsigned IndexWidth = Index.getScalarValueSizeInBits();
 
     // Shrink indices if they are larger than 32-bits.
@@ -56552,6 +56689,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
         Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
         return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
       }
+
+      // Shrink if we remove an illegal type.
+      if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
+        Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
+        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+      }
     }
   }
 
diff --git a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
new file mode 100644
index 0000000000000..faa83b0a20290
--- /dev/null
+++ b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc  -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+
+%struct.pt = type { float, float, float, i32 }
+%struct.res = type {<16 x float>, <16 x float>}
+
+define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0)  {
+; CHECK-LABEL: test_gather_16f32_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT:    vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT:    vpmovb2m %xmm0, %k1
+; CHECK-NEXT:    vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT:    vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT:    vgatherdps (%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT:    vmovaps %zmm1, %zmm0
+; CHECK-NEXT:    retq
+  %wide.load = load <16 x i32>, ptr %arr, align 4
+  %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+  %zext = zext <16 x i32> %and to <16 x i64>
+  %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+  %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+  ret <16 x float> %res
+  }
+
+define <16 x float> @test_gather_16f32_2(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0)  {
+; CHECK-LABEL: test_gather_16f32_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT:    vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT:    vpmovb2m %xmm0, %k1
+; CHECK-NEXT:    vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT:    vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT:    vgatherdps 4(%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT:    vmovaps %zmm1, %zmm0
+; CHECK-NEXT:    retq
+  %wide.load = load <16 x i32>, ptr %arr, align 4
+  %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+  %zext = zext <16 x i32> %and to <16 x i64>
+  %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+  %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+  ret <16 x float> %res
+  }
+
+define {<16 x float>, <16 x float>} @test_gather_16f32_3(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0)  {
+; CHECK-LABEL: test_gather_16f32_3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT:    vpmovb2m %xmm0, %k1
+; CHECK-NEXT:    vmovdqu64 (%rsi), %zmm0
+; CHECK-NEXT:    vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT:    kmovq %k1, %k2
+; CHECK-NEXT:    vpaddd %zmm0, %zmm0, %zmm2
+; CHECK-NEXT:    vmovaps %zmm1, %zmm0
+; CHECK-NEXT:    vgatherdps (%rdi,%zmm2,8), %zmm0 {%k2}
+; CHECK-NEXT:    vgatherdps 4(%rdi,%zmm2,8), %zmm1 {%k1}
+; CHECK-NEXT:    retq
+  %wide.load = load <16 x i32>, ptr %arr, align 4
+  %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+  %zext = zext <16 x i32> %and to <16 x i64>
+  %ptrs1 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+  %res1 = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs1, i32 4, <16 x i1> %mask, <16 x float> %src0)
+  %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+  %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+  %pair1 = insertvalue {<16 x float>, <16 x float>} undef, <16 x float> %res1, 0
+  %pair2 = insertvalue {<16 x float>, <16 x float>} %pair1, <16 x float> %res, 1
+  ret {<16 x float>, <16 x float>} %pair2
+  }



More information about the llvm-commits mailing list