[llvm] 019f022 - [AArch64][SVE] Fold gather/scatter with 32-bit indices when possible

Caroline Concatto via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 3 11:00:17 PST 2022


Author: Caroline Concatto
Date: 2022-02-03T18:58:37Z
New Revision: 019f0221d52dc14829b6011b5bccd9ba4c3849d8

URL: https://github.com/llvm/llvm-project/commit/019f0221d52dc14829b6011b5bccd9ba4c3849d8
DIFF: https://github.com/llvm/llvm-project/commit/019f0221d52dc14829b6011b5bccd9ba4c3849d8.diff

LOG: [AArch64][SVE] Fold gather/scatter with 32-bit indices when possible

In AArch64ISelLowering.cpp this patch implements the following fold:

GEP (%ptr, (splat(%offset) + stepvector(A)))
into GEP ((%ptr + %offset), stepvector(A))

This transform simplifies the index operand so that it can be expressed with
i32 elements, allowing a single gather/scatter assembly instruction to be used
instead of two.
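
For example, with a unit stride the rewrite is roughly equivalent to the
following IR-level change (an illustrative fragment only; the fold itself is
performed on the MGATHER/MSCATTER SelectionDAG nodes and the value names here
are made up, not taken from the patch):

  ; Before: the GEP index is splat(%offset) + stepvector, forcing a 64-bit
  ; index vector.
  %ins   = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %splat = shufflevector <vscale x 4 x i64> %ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step  = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %idx   = add <vscale x 4 x i64> %splat, %step
  %ptrs  = getelementptr i8, i8* %base, <vscale x 4 x i64> %idx

  ; After: %offset is folded into the scalar base pointer, leaving only the
  ; step vector as the index.
  %newbase = getelementptr i8, i8* %base, i64 %offset
  %ptrs    = getelementptr i8, i8* %newbase, <vscale x 4 x i64> %step

The positive tests added below check that the result is selected as a single
ld1/st1 using the 32-bit sign-extended (sxtw) vector offset form.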

Patch by Paul Walker (@paulwalker-arm).

Depends on D118459

Differential Revision: https://reviews.llvm.org/D117900

Added: 
    llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7daae45ced70..910e4252946e5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -889,6 +889,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::VECREDUCE_ADD);
   setTargetDAGCombine(ISD::STEP_VECTOR);
 
+  setTargetDAGCombine(ISD::MGATHER);
+  setTargetDAGCombine(ISD::MSCATTER);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -16358,6 +16361,93 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
+// Analyse the specified address, returning true if a more optimal addressing
+// mode is available. When returning true, all parameters are updated to
+// reflect their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     ISD::MemIndexType &IndexType,
+                                     SelectionDAG &DAG) {
+  // Only consider element types that are pointer sized, as smaller types can
+  // be easily promoted.
+  EVT IndexVT = Index.getValueType();
+  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+    return false;
+
+  int64_t Stride = 0;
+  SDLoc DL(N);
+  // Index = step(const) + splat(offset)
+  if (Index.getOpcode() == ISD::ADD &&
+      Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue StepVector = Index.getOperand(0);
+    if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
+      Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
+      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+    }
+  }
+
+  // Return early if no supported pattern was found.
+  if (Stride == 0)
+    return false;
+
+  if (Stride < std::numeric_limits<int32_t>::min() ||
+      Stride > std::numeric_limits<int32_t>::max())
+    return false;
+
+  const auto &Subtarget =
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+  unsigned MaxVScale =
+      Subtarget.getMaxSVEVectorSizeInBits() / AArch64::SVEBitsPerBlock;
+  int64_t LastElementOffset =
+      IndexVT.getVectorMinNumElements() * Stride * MaxVScale;
+
+  if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
+      LastElementOffset > std::numeric_limits<int32_t>::max())
+    return false;
+
+  EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
+  Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
+                      DAG.getTargetConstant(Stride, DL, MVT::i32));
+  return true;
+}
+
+static SDValue performMaskedGatherScatterCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
+
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
+
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+    return SDValue();
+
+  // Rebuild the gather/scatter node so it uses the more legalisation-friendly
+  // BasePtr and Index computed above.
+  if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+    SDValue PassThru = MGT->getPassThru();
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedGather(
+        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
+        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
+  }
+  auto *MSC = cast<MaskedScatterSDNode>(MGS);
+  SDValue Data = MSC->getValue();
+  SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
+                              Ops, MSC->getMemOperand(), IndexType,
+                              MSC->isTruncatingStore());
+}
+
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
 static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -17820,6 +17910,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
+  case ISD::MSCATTER:
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::VECTOR_SPLICE:
     return performSVESpliceCombine(N, DAG);
   case ISD::FP_EXTEND:
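
Note on the limits exercised by the new tests: findMoreOptimalIndexType only
performs the fold when the constant stride fits in a signed 32-bit integer and
when the largest index the step vector can produce, computed conservatively as
minimum-element-count * maximum-vscale * stride, also fits in a signed 32-bit
integer. For the nxv4 tests below, which use vscale_range(1, 16) (so the
maximum vscale is 16), that product is 64 * stride. As a worked example:

  64 *  33554431 =  2147483584  <= INT32_MAX ( 2147483647)  -> folded
  64 *  33554432 =  2147483648  >  INT32_MAX                -> not folded
  64 * -33554432 = -2147483648  >= INT32_MIN (-2147483648)  -> folded
  64 * -33554433 = -2147483712  <  INT32_MIN                -> not folded

These are exactly the strides used by the "maximum"/"minimum" positive tests
and their "plus one"/"minus one" negative counterparts, while the
scatter_i8_index_stride_too_big test fails the earlier check that the stride
itself must fit in i32.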

diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
new file mode 100644
index 0000000000000..262acd75ca426
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -0,0 +1,212 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-linux-unknown | FileCheck %s
+
+
+; Ensure we use a "vscale x 4" wide scatter for the maximum supported offset.
+define void @scatter_i8_index_offset_maximum(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #33554431
+; CHECK-NEXT:    add x9, x0, x1
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1b { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554431, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we use a "vscale x 4" wide scatter for the minimum supported offset.
+define void @scatter_i16_index_offset_minimum(i16* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i16> %data) #0 {
+; CHECK-LABEL: scatter_i16_index_offset_minimum:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-33554432
+; CHECK-NEXT:    add x9, x0, x1, lsl #1
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554432, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i16, i16* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16> %data, <vscale x 4 x i16*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we use a "vscale x 4" gather for an offset within the limits of 32 bits.
+define <vscale x 4 x i8> @gather_i8_index_offset_8(i8* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_i8_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x1
+; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.insert0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat0 = shufflevector <vscale x 4 x i64> %splat.insert0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %splat.insert1 = insertelement <vscale x 4 x i64> undef, i64 1, i32 0
+  %splat1 = shufflevector <vscale x 4 x i64> %splat.insert1, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t1 = mul <vscale x 4 x i64> %splat1, %step
+  %t2 = add <vscale x 4 x i64> %splat0, %t1
+  %t3 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t2
+  %load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t3, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
+  ret <vscale x 4 x i8> %load
+}
+
+;; Negative tests
+
+; Ensure we don't use a "vscale x 4" wide scatter: we cannot prove that the
+; variable stride will not wrap when narrowed to i32.
+define void @scatter_f16_index_offset_var(half* %base, i64 %offset, i64 %scale, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_offset_var:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z1.d, #0, #1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    mov z2.d, z1.d
+; CHECK-NEXT:    mov z4.d, z3.d
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    incd z2.d
+; CHECK-NEXT:    mla z3.d, p1/m, z1.d, z3.d
+; CHECK-NEXT:    mla z4.d, p1/m, z2.d, z4.d
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    uunpklo z1.d, z0.s
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1h { z1.d }, p1, [x0, z3.d, lsl #1]
+; CHECK-NEXT:    st1h { z0.d }, p0, [x0, z4.d, lsl #1]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 %scale, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr half, half* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too big.
+define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum_plus_one:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov w9, #67108864
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov w9, #33554432
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z1.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554432, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too small.
+define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_minimum_minus_one:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov x9, #-2
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    movk x9, #64511, lsl #16
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov x9, #-33554433
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z1.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554433, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the stride is too big.
+define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_stride_too_big:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov x9, #-9223372036854775808
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov x9, #4611686018427387904
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z1.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 4611686018427387904, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+
+attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
+
+
+declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half*>, i32, <vscale x 4 x i1>)
+declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()

More information about the llvm-commits mailing list