[llvm] 019f022 - [AArch64][SVE] Fold gather/scatter with 32bits when possible
Caroline Concatto via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 3 11:00:17 PST 2022
Author: Caroline Concatto
Date: 2022-02-03T18:58:37Z
New Revision: 019f0221d52dc14829b6011b5bccd9ba4c3849d8
URL: https://github.com/llvm/llvm-project/commit/019f0221d52dc14829b6011b5bccd9ba4c3849d8
DIFF: https://github.com/llvm/llvm-project/commit/019f0221d52dc14829b6011b5bccd9ba4c3849d8.diff
LOG: [AArch64][SVE] Fold gather/scatter with 32bits when possible
In AArch64ISelLowering.cpp this patch implements the following fold:
  GEP (%ptr, (splat(%offset) + stepvector(A)))
into
  GEP ((%ptr + %offset), stepvector(A))
This transform simplifies the index operand so that it can be expressed with
i32 elements, which allows a single gather/scatter instruction to be used
instead of two.
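For example, the index pattern being targeted looks like this in IR (a minimal
sketch with a unit stride, mirroring the gather_i8_index_offset_8 test added
below):

  %ins   = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %splat = shufflevector <vscale x 4 x i64> %ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step  = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %idx   = add <vscale x 4 x i64> %splat, %step
  %ptrs  = getelementptr i8, i8* %base, <vscale x 4 x i64> %idx
  %load  = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %ptrs, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)

With %offset folded into the base pointer, the remaining stepvector index fits
in i32, so this now selects to a single ld1sb using the sxtw-extended 32-bit
offset form instead of being split into two unpacked 64-bit gathers.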
Patch by Paul Walker (@paulwalker-arm).
Depends on D118459
Differential Revision: https://reviews.llvm.org/D117900
Added:
llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7daae45ced70..910e4252946e5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -889,6 +889,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::VECREDUCE_ADD);
setTargetDAGCombine(ISD::STEP_VECTOR);
+ setTargetDAGCombine(ISD::MGATHER);
+ setTargetDAGCombine(ISD::MSCATTER);
+
setTargetDAGCombine(ISD::FP_EXTEND);
setTargetDAGCombine(ISD::GlobalAddress);
@@ -16358,6 +16361,93 @@ static SDValue performSTORECombine(SDNode *N,
return SDValue();
}
+// Analyse the specified address returning true if a more optimal addressing
+// mode is available. When returning true all parameters are updated to reflect
+// their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+ SDValue &BasePtr, SDValue &Index,
+ ISD::MemIndexType &IndexType,
+ SelectionDAG &DAG) {
+ // Only consider element types that are pointer sized as smaller types can
+ // be easily promoted.
+ EVT IndexVT = Index.getValueType();
+ if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+ return false;
+
+ int64_t Stride = 0;
+ SDLoc DL(N);
+ // Index = step(const) + splat(offset)
+ if (Index.getOpcode() == ISD::ADD &&
+ Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+ SDValue StepVector = Index.getOperand(0);
+ if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
+ Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
+ Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+ BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+ }
+ }
+
+ // Return early because no supported pattern is found.
+ if (Stride == 0)
+ return false;
+
+ if (Stride < std::numeric_limits<int32_t>::min() ||
+ Stride > std::numeric_limits<int32_t>::max())
+ return false;
+
+ const auto &Subtarget =
+ static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+ unsigned MaxVScale =
+ Subtarget.getMaxSVEVectorSizeInBits() / AArch64::SVEBitsPerBlock;
+ int64_t LastElementOffset =
+ IndexVT.getVectorMinNumElements() * Stride * MaxVScale;
+
+ if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
+ LastElementOffset > std::numeric_limits<int32_t>::max())
+ return false;
+
+ EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
+ Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
+ DAG.getTargetConstant(Stride, DL, MVT::i32));
+ return true;
+}
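As a worked example of the bound above (assuming the vscale_range(1, 16)
attribute used by the new tests, i.e. MaxVScale = 16, and an nxv4i64 index,
i.e. MinNumElements = 4):

  LastElementOffset = 4 * 16 * Stride = 64 * Stride

so the largest stride that keeps the check within i32 is
floor((2^31 - 1) / 64) = 33554431, the constant exercised by
scatter_i8_index_offset_maximum below, while 33554432 exceeds the bound and
keeps the unpacked 64-bit form (scatter_i8_index_offset_maximum_plus_one).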
+
+static SDValue performMaskedGatherScatterCombine(
+ SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+ MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+ assert(MGS && "Can only combine gather load or scatter store nodes");
+
+ if (!DCI.isBeforeLegalize())
+ return SDValue();
+
+ SDLoc DL(MGS);
+ SDValue Chain = MGS->getChain();
+ SDValue Scale = MGS->getScale();
+ SDValue Index = MGS->getIndex();
+ SDValue Mask = MGS->getMask();
+ SDValue BasePtr = MGS->getBasePtr();
+ ISD::MemIndexType IndexType = MGS->getIndexType();
+
+ if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+ return SDValue();
+
+ // Here we catch such cases early and change MGATHER's IndexType to allow
+ // the use of an Index that's more legalisation friendly.
+ if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+ SDValue PassThru = MGT->getPassThru();
+ SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+ return DAG.getMaskedGather(
+ DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
+ Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
+ }
+ auto *MSC = cast<MaskedScatterSDNode>(MGS);
+ SDValue Data = MSC->getValue();
+ SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+ return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
+ Ops, MSC->getMemOperand(), IndexType,
+ MSC->isTruncatingStore());
+}
+
/// Target-specific DAG combine function for NEON load/store intrinsics
/// to merge base address updates.
static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -17820,6 +17910,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
break;
case ISD::STORE:
return performSTORECombine(N, DCI, DAG, Subtarget);
+ case ISD::MGATHER:
+ case ISD::MSCATTER:
+ return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::VECTOR_SPLICE:
return performSVESpliceCombine(N, DAG);
case ISD::FP_EXTEND:
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
new file mode 100644
index 0000000000000..262acd75ca426
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -0,0 +1,212 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-linux-unknown | FileCheck %s
+
+
+; Ensure we use a "vscale x 4" wide scatter for the maximum supported offset.
+define void @scatter_i8_index_offset_maximum(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov w8, #33554431
+; CHECK-NEXT: add x9, x0, x1
+; CHECK-NEXT: index z1.s, #0, w8
+; CHECK-NEXT: st1b { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT: ret
+ %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+ %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %t2 = insertelement <vscale x 4 x i64> undef, i64 33554431, i32 0
+ %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %t4 = mul <vscale x 4 x i64> %t3, %step
+ %t5 = add <vscale x 4 x i64> %t1, %t4
+ %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+ call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+ ret void
+}
+
+; Ensure we use a "vscale x 4" wide scatter for the minimum supported offset.
+define void @scatter_i16_index_offset_minimum(i16* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i16> %data) #0 {
+; CHECK-LABEL: scatter_i16_index_offset_minimum:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov w8, #-33554432
+; CHECK-NEXT: add x9, x0, x1, lsl #1
+; CHECK-NEXT: index z1.s, #0, w8
+; CHECK-NEXT: st1h { z0.s }, p0, [x9, z1.s, sxtw #1]
+; CHECK-NEXT: ret
+ %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+ %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554432, i32 0
+ %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %t4 = mul <vscale x 4 x i64> %t3, %step
+ %t5 = add <vscale x 4 x i64> %t1, %t4
+ %t6 = getelementptr i16, i16* %base, <vscale x 4 x i64> %t5
+ call void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16> %data, <vscale x 4 x i16*> %t6, i32 2, <vscale x 4 x i1> %pg)
+ ret void
+}
+
+; Ensure we use a "vscale x 4" gather for an offset within the limits of 32 bits.
+define <vscale x 4 x i8> @gather_i8_index_offset_8(i8* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_i8_index_offset_8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: add x8, x0, x1
+; CHECK-NEXT: index z0.s, #0, #1
+; CHECK-NEXT: ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
+; CHECK-NEXT: ret
+ %splat.insert0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+ %splat0 = shufflevector <vscale x 4 x i64> %splat.insert0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %splat.insert1 = insertelement <vscale x 4 x i64> undef, i64 1, i32 0
+ %splat1 = shufflevector <vscale x 4 x i64> %splat.insert1, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %t1 = mul <vscale x 4 x i64> %splat1, %step
+ %t2 = add <vscale x 4 x i64> %splat0, %t1
+ %t3 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t2
+ %load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t3, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
+ ret <vscale x 4 x i8> %load
+}
+
+;; Negative tests
+
+; Ensure we don't use a "vscale x 4" scatter. Cannot prove that variable stride
+; will not wrap when shrunk to be i32 based.
+define void @scatter_f16_index_offset_var(half* %base, i64 %offset, i64 %scale, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_offset_var:
+; CHECK: // %bb.0:
+; CHECK-NEXT: index z1.d, #0, #1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: mov z2.d, z1.d
+; CHECK-NEXT: mov z4.d, z3.d
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: incd z2.d
+; CHECK-NEXT: mla z3.d, p1/m, z1.d, z3.d
+; CHECK-NEXT: mla z4.d, p1/m, z2.d, z4.d
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: uunpklo z1.d, z0.s
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: uunpkhi z0.d, z0.s
+; CHECK-NEXT: st1h { z1.d }, p1, [x0, z3.d, lsl #1]
+; CHECK-NEXT: st1h { z0.d }, p0, [x0, z4.d, lsl #1]
+; CHECK-NEXT: ret
+ %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+ %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %t2 = insertelement <vscale x 4 x i64> undef, i64 %scale, i32 0
+ %t3 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %t4 = mul <vscale x 4 x i64> %t3, %step
+ %t5 = add <vscale x 4 x i64> %t1, %t4
+ %t6 = getelementptr half, half* %base, <vscale x 4 x i64> %t5
+ call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t6, i32 2, <vscale x 4 x i1> %pg)
+ ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too big.
+define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum_plus_one:
+; CHECK: // %bb.0:
+; CHECK-NEXT: rdvl x8, #1
+; CHECK-NEXT: mov w9, #67108864
+; CHECK-NEXT: lsr x8, x8, #4
+; CHECK-NEXT: mov z1.d, x1
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mul x8, x8, x9
+; CHECK-NEXT: mov w9, #33554432
+; CHECK-NEXT: index z2.d, #0, x9
+; CHECK-NEXT: mov z3.d, x8
+; CHECK-NEXT: add z3.d, z2.d, z3.d
+; CHECK-NEXT: add z2.d, z2.d, z1.d
+; CHECK-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEXT: uunpklo z3.d, z0.s
+; CHECK-NEXT: uunpkhi z0.d, z0.s
+; CHECK-NEXT: st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT: st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT: ret
+ %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+ %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %t2 = insertelement <vscale x 4 x i64> undef, i64 33554432, i32 0
+ %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %t4 = mul <vscale x 4 x i64> %t3, %step
+ %t5 = add <vscale x 4 x i64> %t1, %t4
+ %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+ call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+ ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too small.
+define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_minimum_minus_one:
+; CHECK: // %bb.0:
+; CHECK-NEXT: rdvl x8, #1
+; CHECK-NEXT: mov x9, #-2
+; CHECK-NEXT: lsr x8, x8, #4
+; CHECK-NEXT: movk x9, #64511, lsl #16
+; CHECK-NEXT: mov z1.d, x1
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mul x8, x8, x9
+; CHECK-NEXT: mov x9, #-33554433
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: index z2.d, #0, x9
+; CHECK-NEXT: mov z3.d, x8
+; CHECK-NEXT: add z3.d, z2.d, z3.d
+; CHECK-NEXT: add z2.d, z2.d, z1.d
+; CHECK-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEXT: uunpklo z3.d, z0.s
+; CHECK-NEXT: uunpkhi z0.d, z0.s
+; CHECK-NEXT: st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT: st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT: ret
+ %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+ %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554433, i32 0
+ %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %t4 = mul <vscale x 4 x i64> %t3, %step
+ %t5 = add <vscale x 4 x i64> %t1, %t4
+ %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+ call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+ ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the stride is too big.
+define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_stride_too_big:
+; CHECK: // %bb.0:
+; CHECK-NEXT: rdvl x8, #1
+; CHECK-NEXT: mov x9, #-9223372036854775808
+; CHECK-NEXT: lsr x8, x8, #4
+; CHECK-NEXT: mov z1.d, x1
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mul x8, x8, x9
+; CHECK-NEXT: mov x9, #4611686018427387904
+; CHECK-NEXT: index z2.d, #0, x9
+; CHECK-NEXT: mov z3.d, x8
+; CHECK-NEXT: add z3.d, z2.d, z3.d
+; CHECK-NEXT: add z2.d, z2.d, z1.d
+; CHECK-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEXT: uunpklo z3.d, z0.s
+; CHECK-NEXT: uunpkhi z0.d, z0.s
+; CHECK-NEXT: st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT: st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT: ret
+ %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+ %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %t2 = insertelement <vscale x 4 x i64> undef, i64 4611686018427387904, i32 0
+ %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+ %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %t4 = mul <vscale x 4 x i64> %t3, %step
+ %t5 = add <vscale x 4 x i64> %t1, %t4
+ %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+ call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+ ret void
+}
+
+
+attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
+
+
+declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half*>, i32, <vscale x 4 x i1>)
+declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()