[llvm] 961e954 - [AArch64][SVE] Add more folds to make use of gather/scatter with 32-bit indices

Caroline Concatto via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 3 11:23:21 PST 2022


Author: Caroline Concatto
Date: 2022-02-03T19:18:30Z
New Revision: 961e954af592771f1323e41a049d25f433eaabb6

URL: https://github.com/llvm/llvm-project/commit/961e954af592771f1323e41a049d25f433eaabb6
DIFF: https://github.com/llvm/llvm-project/commit/961e954af592771f1323e41a049d25f433eaabb6.diff

LOG: [AArch64][SVE] Add more folds to make use of gather/scatter with 32-bit indices

In AArch64ISelLowering.cpp, this patch implements the following fold:

1) GEP (%ptr, SHL ((stepvector(A) + splat(%offset)), splat(B)))
   into GEP (%ptr + (%offset << B), step_vector(A << B))

This transform simplifies the index operand so that it can be expressed with
i32 elements, which allows a single gather/scatter instruction to be used
instead of two.
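
As a concrete example, the new gather_8i8_index_offset_8 test below follows
this shape: the IR builds a 64-bit index as stepvector + splat(%offset), and
the [8 x i8] GEP scale of 8 roughly becomes the splat(B) shift once lowered.
After the fold, the shift and offset are moved into the scalar base address
and the remaining index fits in i32 elements, so the new CHECK lines expect a
single gather:

  ; IR index computation (from the test below)
  %t2 = add <vscale x 4 x i64> %t1, %step
  %t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2

  ; Resulting code (from the CHECK lines): scalar base add + 32-bit indices
  add x8, x0, x1, lsl #3                    // x8 = %base + (%offset << 3)
  index z0.s, #0, #8                        // 32-bit step_vector, stride 8
  ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]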

Patch by Paul Walker (@paulwalker-arm).

Depends on D117900

Differential Revision: https://reviews.llvm.org/D118345

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 910e4252946e..0a1b2628a315 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16387,6 +16387,29 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
     }
   }
 
+  // Index = shl(step(const) + splat(offset), splat(shift))
+  if (Index.getOpcode() == ISD::SHL &&
+      Index.getOperand(0).getOpcode() == ISD::ADD &&
+      Index.getOperand(0).getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue Add = Index.getOperand(0);
+    SDValue ShiftOp = Index.getOperand(1);
+    SDValue StepOp = Add.getOperand(0);
+    SDValue OffsetOp = Add.getOperand(1);
+    if (auto *Shift =
+            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(ShiftOp)))
+      if (auto Offset = DAG.getSplatValue(OffsetOp)) {
+        int64_t Step =
+            cast<ConstantSDNode>(StepOp.getOperand(0))->getSExtValue();
+        // The stride is not explicitly scaled by 'Scale' here, because that
+        // scaling is applied by the gather/scatter addressing mode.
+        Stride = Step << Shift->getSExtValue();
+        // BasePtr = BasePtr + ((Offset * Scale) << Shift)
+        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, SDValue(Shift, 0));
+        BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+      }
+  }
+
   // Return early because no supported pattern is found.
   if (Stride == 0)
     return false;

diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index 262acd75ca42..7da424005221 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -201,9 +201,92 @@ define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4
   ret void
 }
 
+; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
+; impression the gather must be split due to its <vscale x 4 x i64> offset.
+; gather_i8(base, index(offset, 8 * sizeof(i8)))
+define <vscale x 4 x i8> @gather_8i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_8i8_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x1, lsl #3
+; CHECK-NEXT:    index z0.s, #0, #8
+; CHECK-NEXT:    ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
+  %load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
+  ret <vscale x 4 x i8> %load
+}
+
+; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
+; impression the gather must be split due to its <vscale x 4 x i64> offset.
+; gather_f32(base, index(offset, 8 * sizeof(float)))
+define <vscale x 4 x float> @gather_f32_index_offset_8([8 x float]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_f32_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #32
+; CHECK-NEXT:    add x9, x0, x1, lsl #5
+; CHECK-NEXT:    index z0.s, #0, w8
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x9, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x float], [8 x float]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x float]*> %t3 to <vscale x 4 x float*>
+  %load = call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %load
+}
+
+; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
+; impression the scatter must be split due to its <vscale x 4 x i64> offset.
+; scatter_i8(base, index(offset, 8 * sizeof(i8)))
+define void @scatter_i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x1, lsl #3
+; CHECK-NEXT:    index z1.s, #0, #8
+; CHECK-NEXT:    st1b { z0.s }, p0, [x8, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t4, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
+; impression the scatter must be split due to its <vscale x 4 x i64> offset.
+; scatter_f16(base, index(offset, 8 * sizeof(half)))
+define void @scatter_f16_index_offset_8([8 x half]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16
+; CHECK-NEXT:    add x9, x0, x1, lsl #4
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x half]*> %t3 to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t4, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
 
 attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
 
+declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
 
 declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)


        

