[llvm] 201e368 - [AArch64][SVE] Handle more cases in findMoreOptimalIndexType.
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 28 04:16:42 PST 2022
Author: Sander de Smalen
Date: 2022-02-28T12:13:52Z
New Revision: 201e3686ab4533213e6d72d192d4493002ffa679
URL: https://github.com/llvm/llvm-project/commit/201e3686ab4533213e6d72d192d4493002ffa679
DIFF: https://github.com/llvm/llvm-project/commit/201e3686ab4533213e6d72d192d4493002ffa679.diff
LOG: [AArch64][SVE] Handle more cases in findMoreOptimalIndexType.
This patch addresses @paulwalker-arm's comment on D117900 by only
updating/writing the by-ref operands iff the function returns true.
It also handles a few more cases where a series of added offsets can
be folded into the base pointer, rather than looking at just a single
offset.
Reviewed By: paulwalker-arm
Differential Revision: https://reviews.llvm.org/D119728
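As an illustration of the new folding (a sketch only; the operand names
below are made up for this note and do not appear in the patch), an
address whose index contains a chain of splatted offsets, e.g.

    BasePtr = %base
    Index   = (step_vector + splat(%off1)) + splat(%off2)

can now be rewritten by repeatedly peeling the splats off the index and
adding them, scaled, to the scalar base pointer:

    BasePtr = %base + (%off1 + %off2) * Scale
    Index   = step_vector

leaving a plain step_vector index for the existing stride matching to handle.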
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9754df1a6a641..daadcc048b163 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16476,55 +16476,90 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
-// Analyse the specified address returning true if a more optimal addressing
-// mode is available. When returning true all parameters are updated to reflect
-// their recommended values.
-static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
-                                     SDValue &BasePtr, SDValue &Index,
-                                     ISD::MemIndexType &IndexType,
-                                     SelectionDAG &DAG) {
-  // Only consider element types that are pointer sized as smaller types can
-  // be easily promoted.
+/// \return true if part of the index was folded into the Base.
+static bool foldIndexIntoBase(SDValue &BasePtr, SDValue &Index, SDValue Scale,
+                              SDLoc DL, SelectionDAG &DAG) {
+  // This function assumes a vector of i64 indices.
   EVT IndexVT = Index.getValueType();
-  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+  if (!IndexVT.isVector() || IndexVT.getVectorElementType() != MVT::i64)
     return false;
-  int64_t Stride = 0;
-  SDLoc DL(N);
-  // Index = step(const) + splat(offset)
-  if (Index.getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
-    SDValue StepVector = Index.getOperand(0);
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = X + splat(Offset)
+  // ->
+  //   BasePtr = Ptr + Offset * scale.
+  //   Index = X
+  if (Index.getOpcode() == ISD::ADD) {
     if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
-      Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
-      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
       BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+      Index = Index.getOperand(0);
+      return true;
     }
   }
-  // Index = shl((step(const) + splat(offset))), splat(shift))
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = (X + splat(Offset)) << splat(Shift)
+  // ->
+  //   BasePtr = Ptr + (Offset << Shift) * scale
+  //   Index = X << splat(shift)
   if (Index.getOpcode() == ISD::SHL &&
-      Index.getOperand(0).getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+      Index.getOperand(0).getOpcode() == ISD::ADD) {
     SDValue Add = Index.getOperand(0);
     SDValue ShiftOp = Index.getOperand(1);
-    SDValue StepOp = Add.getOperand(0);
     SDValue OffsetOp = Add.getOperand(1);
-    if (auto *Shift =
-            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(ShiftOp)))
+    if (auto Shift = DAG.getSplatValue(ShiftOp))
       if (auto Offset = DAG.getSplatValue(OffsetOp)) {
-        int64_t Step =
-            cast<ConstantSDNode>(StepOp.getOperand(0))->getSExtValue();
-        // Stride does not scale explicitly by 'Scale', because it happens in
-        // the gather/scatter addressing mode.
-        Stride = Step << Shift->getSExtValue();
-        // BasePtr = BasePtr + ((Offset * Scale) << Shift)
-        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
-        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, SDValue(Shift, 0));
+        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, Shift);
+        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
         BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+        Index = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+                            Add.getOperand(0), ShiftOp);
+        return true;
       }
   }
+  return false;
+}
+
+// Analyse the specified address returning true if a more optimal addressing
+// mode is available. When returning true all parameters are updated to reflect
+// their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     SelectionDAG &DAG) {
+  // Only consider element types that are pointer sized as smaller types can
+  // be easily promoted.
+  EVT IndexVT = Index.getValueType();
+  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+    return false;
+
+  // Try to iteratively fold parts of the index into the base pointer to
+  // simplify the index as much as possible.
+  SDValue NewBasePtr = BasePtr, NewIndex = Index;
+  while (foldIndexIntoBase(NewBasePtr, NewIndex, N->getScale(), SDLoc(N), DAG))
+    ;
+
+  // Match:
+  //   Index = step(const)
+  int64_t Stride = 0;
+  if (NewIndex.getOpcode() == ISD::STEP_VECTOR)
+    Stride = cast<ConstantSDNode>(NewIndex.getOperand(0))->getSExtValue();
+
+  // Match:
+  //   Index = step(const) << shift(const)
+  else if (NewIndex.getOpcode() == ISD::SHL &&
+           NewIndex.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue RHS = NewIndex.getOperand(1);
+    if (auto *Shift =
+            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(RHS))) {
+      int64_t Step = (int64_t)NewIndex.getOperand(0).getConstantOperandVal(1);
+      Stride = Step << Shift->getZExtValue();
+    }
+  }
+
   // Return early because no supported pattern is found.
   if (Stride == 0)
     return false;
@@ -16545,8 +16580,11 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
     return false;
   EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
-  Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
-                      DAG.getTargetConstant(Stride, DL, MVT::i32));
+  // Stride does not scale explicitly by 'Scale', because it happens in
+  // the gather/scatter addressing mode.
+  Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
+                      DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
+  BasePtr = NewBasePtr;
   return true;
 }
@@ -16566,7 +16604,7 @@ static SDValue performMaskedGatherScatterCombine(
   SDValue BasePtr = MGS->getBasePtr();
   ISD::MemIndexType IndexType = MGS->getIndexType();
-  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
     return SDValue();
   // Here we catch such cases early and change MGATHER's IndexType to allow
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index dd5a1264f8700..e894291845d41 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -283,7 +283,54 @@ define void @scatter_f16_index_offset_8([8 x half]* %base, i64 %offset, <vscale
   ret void
 }
+; stepvector is hidden further behind GEP and two adds.
+define void @scatter_f16_index_add_add([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16
+; CHECK-NEXT:    add x9, x0, x2, lsl #4
+; CHECK-NEXT:    add x9, x9, x1, lsl #4
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %add2
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+; stepvector is hidden further behind GEP, two adds and a shift.
+define void @scatter_f16_index_add_add_mul([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #128
+; CHECK-NEXT:    add x9, x0, x2, lsl #7
+; CHECK-NEXT:    add x9, x9, x1, lsl #7
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %splat.const8.ins = insertelement <vscale x 4 x i64> undef, i64 8, i32 0
+  %splat.const8 = shufflevector <vscale x 4 x i64> %splat.const8.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %mul = mul <vscale x 4 x i64> %add2, %splat.const8
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %mul
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
 attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
 declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)