[llvm] [X86][SelectionDAG] Fix the Gather's base and index by modifying the Scale value (PR #134979)
Rohit Aggarwal via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 9 04:33:48 PDT 2025
https://github.com/rohitaggarwal007 updated https://github.com/llvm/llvm-project/pull/134979
>From 741acb05b09ff5333c0166f191e4ad2ffab88496 Mon Sep 17 00:00:00 2001
From: Rohit Aggarwal <Rohit.Aggarwal at amd.com>
Date: Wed, 19 Mar 2025 15:04:43 +0530
Subject: [PATCH 1/4] Fix the Gather's base and index for one use or multiple
uses of the Index node. The approach folds an SHL of the Index into the Scale,
then truncates the Index.
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 4 +-
llvm/lib/Target/X86/X86ISelLowering.cpp | 143 ++++++++++++++++++
llvm/test/CodeGen/X86/gatherBaseIndexFix.ll | 68 +++++++++
3 files changed, 213 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 38376de5783ae..7c51ee8222512 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12131,8 +12131,8 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
if (IndexIsScaled)
return false;
- if (!isNullConstant(BasePtr) && !Index.hasOneUse())
- return false;
+ // if (!isNullConstant(BasePtr) && !Index.hasOneUse())
+ // return false;
EVT VT = BasePtr.getValueType();
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 47ac1ee571269..61e6d0734f402 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56512,6 +56512,120 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
Scatter->isTruncatingStore());
}
+// Tries to rewrite Base, Index and Scale for a gather/scatter whose pointer
+// operand comes from a non-uniform GEP; returns true on a successful rewrite.
+static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, SDValue &Scale,
+ const SDLoc &DL, const SDValue &Gep,
+ SelectionDAG &DAG) {
+ SDValue Nbase;
+ SDValue Nindex;
+ SDValue NScale;
+ bool Changed = false;
+ // Checks whether Idx is a foldable SHL; on success sets Nindex and NScale.
+ auto checkAndUpdateIndex = [&](SDValue &Idx) {
+ if (Idx.getOpcode() == ISD::SHL) { // shl (ext X), BV
+ SDValue Op10 = Idx.getOperand(0); // Expected to be a zext/sext (checked below)
+ SDValue Op11 = Idx.getOperand(1); // Shift amount; only used by commented-out code
+ std::optional<uint64_t> ShAmt = DAG.getValidMinimumShiftAmount(Idx);
+
+ unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+ if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+ Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+ IndexWidth > 32 &&
+ Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+ DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && ShAmt) {
+
+ KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+ bool ExtIsNonNegative = ExtKnown.isNonNegative();
+ KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+ bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+ if (!ExtIsNonNegative || !ExtOpIsNonNegative) // Require both known non-negative
+ return false;
+
+ SDValue NewOp10 =
+ Op10.getOperand(0); // The value being extended
+ EVT VT = NewOp10.getValueType(); // Pre-extension type; only referenced by
+ // the commented-out code below
+
+ // auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+ // if (!ConstEltNo)
+ // return false;
+ uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue(); // NOTE(review): unchecked cast - assumes Scale is a constant; confirm
+ uint64_t NewScaleAmt = ScaleAmt * (1ULL << *ShAmt); // Fold the shift into the scale
+ LLVM_DEBUG(dbgs() << NewScaleAmt << " NewScaleAmt"
+ << "\n");
+ if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) { // Accept only power-of-two scales up to 8
+ // Nindex = NewOp10.getOperand(0);
+ Nindex = Op10; // New index is the unshifted extended value
+ NScale = DAG.getTargetConstant(NewScaleAmt, DL, Scale.getValueType());
+ return true;
+ }
+ // SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+ // DAG.getConstant(ConstEltNo->getZExtValue(),
+ // DL, VT.getScalarType()));
+ // Nindex = DAG.getNode(ISD::SHL, DL, VT, NewOp10,
+ // DAG.getBuildVector(VT, DL, Ops));
+ }
+ }
+ return false;
+ };
+
+ // For the GEP, try to properly assign the base and index values by walking
+ // backward through the lowered code (the ADD nodes matched below).
+ if (isNullConstant(Base) && Gep.getOpcode() == ISD::ADD) {
+ SDValue Op0 = Gep.getOperand(0); // Base or a nested ADD
+ SDValue Op1 = Gep.getOperand(1); // Build vector or SHL
+ Nbase = Op0;
+ SDValue Idx = Op1;
+ auto Flags = Gep->getFlags();
+
+ if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
+ SDValue Op00 = Op0.getOperand(0); // Base
+ Nbase = Op00;
+ Idx = Op0.getOperand(1);
+ } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+ Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+ return false;
+ }
+ if (!checkAndUpdateIndex(Idx)) {
+ return false;
+ }
+ Base = Nbase.getOperand(0); // NOTE(review): Base is overwritten here even if we return false below
+
+ if (Op0 != Nbase) { // Nested-ADD case: fold Op1's first operand into Base
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+ if (!ConstEltNo)
+ return false;
+
+ // SmallVector<SDValue, 8> Ops(
+ // Nindex.getValueType().getVectorNumElements(),
+ // DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+ // Nindex.getValueType().getScalarType()));
+ Base = DAG.getNode(ISD::ADD, DL, Nbase.getOperand(0).getValueType(),
+ Nbase.getOperand(0), Op1.getOperand(0), Flags);
+ }
+ Index = Nindex;
+ Scale = NScale;
+ Changed = true;
+ } else if (Base.getOpcode() == ISD::CopyFromReg ||
+ (Base.getOpcode() == ISD::ADD &&
+ Base.getOperand(0).getOpcode() == ISD::CopyFromReg &&
+ isConstOrConstSplat(Base.getOperand(1)))) {
+ if (checkAndUpdateIndex(Index)) {
+ Index = Nindex; // NOTE(review): NScale is not written back to Scale in this branch - confirm intended
+ Changed = true;
+ }
+ }
+ if (Changed) {
+ LLVM_DEBUG(dbgs() << "Successful in updating the non uniform gep "
+ "information\n";
+ dbgs() << "updated base "; Base.dump();
+ dbgs() << "updated Index "; Index.dump(););
+ return true;
+ }
+ return false;
+}
+
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
@@ -56523,6 +56637,29 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (DCI.isBeforeLegalize()) {
+ // if (updateBaseAndIndex(Base, Index, Scale, DL, Index, DAG))
+ // return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ //
+
+ // Attempt to move shifted index into the address scale, allows further
+ // index truncation below.
+ if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale)) {
+ uint64_t ScaleAmt = Scale->getAsZExtVal();
+ if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
+ if (*MinShAmt >= 1 && ScaleAmt < 8 &&
+ DAG.ComputeNumSignBits(Index.getOperand(0)) > 1) {
+ SDValue ShAmt = Index.getOperand(1);
+ SDValue NewShAmt =
+ DAG.getNode(ISD::SUB, DL, ShAmt.getValueType(), ShAmt,
+ DAG.getConstant(1, DL, ShAmt.getValueType()));
+ SDValue NewIndex = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+ Index.getOperand(0), NewShAmt);
+ SDValue NewScale =
+ DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
+ return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+ }
+ }
+ }
unsigned IndexWidth = Index.getScalarValueSizeInBits();
// Shrink indices if they are larger than 32-bits.
@@ -56552,6 +56689,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}
+
+ // Shrink if we remove an illegal type.
+ if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
+ Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
+ return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ }
}
}
diff --git a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
new file mode 100644
index 0000000000000..faa83b0a20290
--- /dev/null
+++ b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+
+%struct.pt = type { float, float, float, i32 }
+%struct.res = type {<16 x float>, <16 x float>}
+
+define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_1:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT: vgatherdps (%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ ret <16 x float> %res
+ }
+
+define <16 x float> @test_gather_16f32_2(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT: vgatherdps 4(%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ ret <16 x float> %res
+ }
+
+define {<16 x float>, <16 x float>} @test_gather_16f32_3(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_3:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm0
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT: kmovq %k1, %k2
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm2
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: vgatherdps (%rdi,%zmm2,8), %zmm0 {%k2}
+; CHECK-NEXT: vgatherdps 4(%rdi,%zmm2,8), %zmm1 {%k1}
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs1 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+ %res1 = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs1, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ %pair1 = insertvalue {<16 x float>, <16 x float>} undef, <16 x float> %res1, 0
+ %pair2 = insertvalue {<16 x float>, <16 x float>} %pair1, <16 x float> %res, 1
+ ret {<16 x float>, <16 x float>} %pair2
+ }
>From 9d395b137899a3c985444e191756a04400b42287 Mon Sep 17 00:00:00 2001
From: Rohit Aggarwal <Rohit.Aggarwal at amd.com>
Date: Wed, 19 Mar 2025 15:04:43 +0530
Subject: [PATCH 2/4] Fix the Gather's base and index for one use or multiple
uses of the Index node. The approach folds an SHL of the Index into the Scale,
then truncates the Index.
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 3 -
llvm/lib/Target/X86/X86ISelLowering.cpp | 143 ++++++++++++++++++
llvm/test/CodeGen/X86/gatherBaseIndexFix.ll | 68 +++++++++
3 files changed, 211 insertions(+), 3 deletions(-)
create mode 100644 llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 38376de5783ae..a727d63c95019 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12131,9 +12131,6 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
if (IndexIsScaled)
return false;
- if (!isNullConstant(BasePtr) && !Index.hasOneUse())
- return false;
-
EVT VT = BasePtr.getValueType();
if (SDValue SplatVal = DAG.getSplatValue(Index);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 47ac1ee571269..61e6d0734f402 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56512,6 +56512,120 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
Scatter->isTruncatingStore());
}
+// Tries to rewrite Base, Index and Scale for a gather/scatter whose pointer
+// operand comes from a non-uniform GEP; returns true on a successful rewrite.
+static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, SDValue &Scale,
+ const SDLoc &DL, const SDValue &Gep,
+ SelectionDAG &DAG) {
+ SDValue Nbase;
+ SDValue Nindex;
+ SDValue NScale;
+ bool Changed = false;
+ // Checks whether Idx is a foldable SHL; on success sets Nindex and NScale.
+ auto checkAndUpdateIndex = [&](SDValue &Idx) {
+ if (Idx.getOpcode() == ISD::SHL) { // shl (ext X), BV
+ SDValue Op10 = Idx.getOperand(0); // Expected to be a zext/sext (checked below)
+ SDValue Op11 = Idx.getOperand(1); // Shift amount; only used by commented-out code
+ std::optional<uint64_t> ShAmt = DAG.getValidMinimumShiftAmount(Idx);
+
+ unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+ if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+ Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+ IndexWidth > 32 &&
+ Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+ DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && ShAmt) {
+
+ KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+ bool ExtIsNonNegative = ExtKnown.isNonNegative();
+ KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+ bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+ if (!ExtIsNonNegative || !ExtOpIsNonNegative) // Require both known non-negative
+ return false;
+
+ SDValue NewOp10 =
+ Op10.getOperand(0); // The value being extended
+ EVT VT = NewOp10.getValueType(); // Pre-extension type; only referenced by
+ // the commented-out code below
+
+ // auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+ // if (!ConstEltNo)
+ // return false;
+ uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue(); // NOTE(review): unchecked cast - assumes Scale is a constant; confirm
+ uint64_t NewScaleAmt = ScaleAmt * (1ULL << *ShAmt); // Fold the shift into the scale
+ LLVM_DEBUG(dbgs() << NewScaleAmt << " NewScaleAmt"
+ << "\n");
+ if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) { // Accept only power-of-two scales up to 8
+ // Nindex = NewOp10.getOperand(0);
+ Nindex = Op10; // New index is the unshifted extended value
+ NScale = DAG.getTargetConstant(NewScaleAmt, DL, Scale.getValueType());
+ return true;
+ }
+ // SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+ // DAG.getConstant(ConstEltNo->getZExtValue(),
+ // DL, VT.getScalarType()));
+ // Nindex = DAG.getNode(ISD::SHL, DL, VT, NewOp10,
+ // DAG.getBuildVector(VT, DL, Ops));
+ }
+ }
+ return false;
+ };
+
+ // For the GEP, try to properly assign the base and index values by walking
+ // backward through the lowered code (the ADD nodes matched below).
+ if (isNullConstant(Base) && Gep.getOpcode() == ISD::ADD) {
+ SDValue Op0 = Gep.getOperand(0); // Base or a nested ADD
+ SDValue Op1 = Gep.getOperand(1); // Build vector or SHL
+ Nbase = Op0;
+ SDValue Idx = Op1;
+ auto Flags = Gep->getFlags();
+
+ if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
+ SDValue Op00 = Op0.getOperand(0); // Base
+ Nbase = Op00;
+ Idx = Op0.getOperand(1);
+ } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+ Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+ return false;
+ }
+ if (!checkAndUpdateIndex(Idx)) {
+ return false;
+ }
+ Base = Nbase.getOperand(0); // NOTE(review): Base is overwritten here even if we return false below
+
+ if (Op0 != Nbase) { // Nested-ADD case: fold Op1's first operand into Base
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+ if (!ConstEltNo)
+ return false;
+
+ // SmallVector<SDValue, 8> Ops(
+ // Nindex.getValueType().getVectorNumElements(),
+ // DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+ // Nindex.getValueType().getScalarType()));
+ Base = DAG.getNode(ISD::ADD, DL, Nbase.getOperand(0).getValueType(),
+ Nbase.getOperand(0), Op1.getOperand(0), Flags);
+ }
+ Index = Nindex;
+ Scale = NScale;
+ Changed = true;
+ } else if (Base.getOpcode() == ISD::CopyFromReg ||
+ (Base.getOpcode() == ISD::ADD &&
+ Base.getOperand(0).getOpcode() == ISD::CopyFromReg &&
+ isConstOrConstSplat(Base.getOperand(1)))) {
+ if (checkAndUpdateIndex(Index)) {
+ Index = Nindex; // NOTE(review): NScale is not written back to Scale in this branch - confirm intended
+ Changed = true;
+ }
+ }
+ if (Changed) {
+ LLVM_DEBUG(dbgs() << "Successful in updating the non uniform gep "
+ "information\n";
+ dbgs() << "updated base "; Base.dump();
+ dbgs() << "updated Index "; Index.dump(););
+ return true;
+ }
+ return false;
+}
+
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
@@ -56523,6 +56637,29 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (DCI.isBeforeLegalize()) {
+ // if (updateBaseAndIndex(Base, Index, Scale, DL, Index, DAG))
+ // return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ //
+
+ // Attempt to move shifted index into the address scale, allows further
+ // index truncation below.
+ if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale)) {
+ uint64_t ScaleAmt = Scale->getAsZExtVal();
+ if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
+ if (*MinShAmt >= 1 && ScaleAmt < 8 &&
+ DAG.ComputeNumSignBits(Index.getOperand(0)) > 1) {
+ SDValue ShAmt = Index.getOperand(1);
+ SDValue NewShAmt =
+ DAG.getNode(ISD::SUB, DL, ShAmt.getValueType(), ShAmt,
+ DAG.getConstant(1, DL, ShAmt.getValueType()));
+ SDValue NewIndex = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+ Index.getOperand(0), NewShAmt);
+ SDValue NewScale =
+ DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
+ return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+ }
+ }
+ }
unsigned IndexWidth = Index.getScalarValueSizeInBits();
// Shrink indices if they are larger than 32-bits.
@@ -56552,6 +56689,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}
+
+ // Shrink if we remove an illegal type.
+ if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
+ Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
+ return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ }
}
}
diff --git a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
new file mode 100644
index 0000000000000..faa83b0a20290
--- /dev/null
+++ b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+
+%struct.pt = type { float, float, float, i32 }
+%struct.res = type {<16 x float>, <16 x float>}
+
+define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_1:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT: vgatherdps (%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ ret <16 x float> %res
+ }
+
+define <16 x float> @test_gather_16f32_2(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT: vgatherdps 4(%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ ret <16 x float> %res
+ }
+
+define {<16 x float>, <16 x float>} @test_gather_16f32_3(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_3:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm0
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT: kmovq %k1, %k2
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm2
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: vgatherdps (%rdi,%zmm2,8), %zmm0 {%k2}
+; CHECK-NEXT: vgatherdps 4(%rdi,%zmm2,8), %zmm1 {%k1}
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs1 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+ %res1 = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs1, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ %pair1 = insertvalue {<16 x float>, <16 x float>} undef, <16 x float> %res1, 0
+ %pair2 = insertvalue {<16 x float>, <16 x float>} %pair1, <16 x float> %res, 1
+ ret {<16 x float>, <16 x float>} %pair2
+ }
>From 7eb76638fcad21977183d045d2655267a8dddb98 Mon Sep 17 00:00:00 2001
From: Rohit Aggarwal <44664450+rohitaggarwal007 at users.noreply.github.com>
Date: Wed, 9 Apr 2025 16:53:45 +0530
Subject: [PATCH 3/4] Update gatherBaseIndexFix.ll
---
llvm/test/CodeGen/X86/gatherBaseIndexFix.ll | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
index a1d65a36410b4..a08ab5a936fa2 100644
--- a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
+++ b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -67,4 +67,4 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_3(ptr %x, ptr %arr, <16 x
ret {<16 x float>, <16 x float>} %pair2
}
-declare <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr>, i32, <16 x i1>, <16 x float>)
\ No newline at end of file
+declare <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr>, i32, <16 x i1>, <16 x float>)
>From 4ed4a4d0318e63c930b286e903c4ff2123c088a7 Mon Sep 17 00:00:00 2001
From: Rohit Aggarwal <Rohit.Aggarwal at amd.com>
Date: Wed, 9 Apr 2025 17:00:44 +0530
Subject: [PATCH 4/4] squash! Changes
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 61e6d0734f402..afdaa485ccc88 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56643,6 +56643,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
// Attempt to move shifted index into the address scale, allows further
// index truncation below.
+ // TODO
if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale)) {
uint64_t ScaleAmt = Scale->getAsZExtVal();
if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
More information about the llvm-commits
mailing list