[llvm] 73ecb99 - [AArch64][SVE] Fix assertion failure when lowering fixed length gather/scatter

Bradley Smith via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 9 07:05:39 PDT 2021


Author: Bradley Smith
Date: 2021-08-09T14:05:22Z
New Revision: 73ecb9987b00db274b7b2ac34b0602ffdb906a4b

URL: https://github.com/llvm/llvm-project/commit/73ecb9987b00db274b7b2ac34b0602ffdb906a4b
DIFF: https://github.com/llvm/llvm-project/commit/73ecb9987b00db274b7b2ac34b0602ffdb906a4b.diff

LOG: [AArch64][SVE] Fix assertion failure when lowering fixed length gather/scatter

The patterns for fixed length gather/scatter with 32-bit offsets and a
64-bit memory type are slightly different from the rest of the patterns,
so the lowering needs to differ as well: the container type must be
derived from the memory type rather than from the index type to ensure
the correct types are used.

Differential Revision: https://reviews.llvm.org/D107576
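
For reference, the shape of input that previously tripped the assertion is a
fixed length masked gather whose offsets are 32-bit while its memory type is
64-bit. Below is a minimal LLVM IR sketch distilled from the
masked_gather_32b_scaled_sext_f64 test added by this patch; the function name
is illustrative, and the #0 attribute set (which enables SVE for fixed length
vectors) is assumed to be as defined in that test file:

  ; 32-bit offsets, sign-extended to 64 bits and scaled for a 64-bit element type
  define <32 x double> @gather_sketch(<32 x i32>* %b, double* %base, <32 x i1> %mask) #0 {
    %idxs = load <32 x i32>, <32 x i32>* %b
    %ext = sext <32 x i32> %idxs to <32 x i64>
    %ptrs = getelementptr double, double* %base, <32 x i64> %ext
    %vals = call <32 x double> @llvm.masked.gather.v32f64(<32 x double*> %ptrs, i32 8, <32 x i1> %mask, <32 x double> undef)
    ret <32 x double> %vals
  }

  declare <32 x double> @llvm.masked.gather.v32f64(<32 x double*>, i32, <32 x i1>, <32 x double>)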

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
    llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4bf2635d8caea..aa9cb463d8aa8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4360,8 +4360,13 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");
-    IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-    MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
+      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
+      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    } else {
+      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
+      IndexVT = MemVT.changeTypeToInteger();
+    }
     InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
     Mask = DAG.getNode(
         ISD::ZERO_EXTEND, DL,
@@ -4460,8 +4465,13 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");
-    IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-    MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
+      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
+      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    } else {
+      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
+      IndexVT = MemVT.changeTypeToInteger();
+    }
     InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
 
     StoreVal =

diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
index 20b8d10ebe30a..ee64e6db65689 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -917,8 +917,8 @@ define void @masked_gather_v32f64(<32 x double>* %a, <32 x double*>* %b) #0 {
 ; The above tests test the types, the below tests check that the addressing
 ; modes still function
 
-define void @masked_gather_32b_scaled_sext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
-; CHECK-LABEL: masked_gather_32b_scaled_sext:
+define void @masked_gather_32b_scaled_sext_f16(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
+; CHECK-LABEL: masked_gather_32b_scaled_sext_f16:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32
 ; VBITS_GE_2048-NEXT: ld1h { [[VALS:z[0-9]+]].h }, [[PG0]]/z, [x0]
 ; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
@@ -941,6 +941,45 @@ define void @masked_gather_32b_scaled_sext(<32 x half>* %a, <32 x i32>* %b, half
   ret void
 }
 
+define void @masked_gather_32b_scaled_sext_f32(<32 x float>* %a, <32 x i32>* %b, float* %base) #0 {
+; CHECK-LABEL: masked_gather_32b_scaled_sext_f32:
+; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[VALS:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].s, [[PG]]/z, [[VALS]].s, #0.0
+; VBITS_GE_2048-NEXT: ld1w { [[RES:z[0-9]+]].s }, [[MASK]]/z, [x2, [[PTRS]].s, sxtw #2]
+; VBITS_GE_2048-NEXT: st1w { [[RES]].s }, [[PG]], [x0]
+; VBITS_GE_2048-NEXT: ret
+  %cvals = load <32 x float>, <32 x float>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr float, float* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x float> %cvals, zeroinitializer
+  %vals = call <32 x float> @llvm.masked.gather.v32f32(<32 x float*> %ptrs, i32 8, <32 x i1> %mask, <32 x float> undef)
+  store <32 x float> %vals, <32 x float>* %a
+  ret void
+}
+
+define void @masked_gather_32b_scaled_sext_f64(<32 x double>* %a, <32 x i32>* %b, double* %base) #0 {
+; CHECK-LABEL: masked_gather_32b_scaled_sext_f64:
+; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].d, vl32
+; VBITS_GE_2048-NEXT: ld1d { [[VALS:z[0-9]+]].d }, [[PG0]]/z, [x0]
+; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG1]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].d, [[PG0]]/z, [[VALS]].d, #0.0
+; VBITS_GE_2048-NEXT: ld1d { [[RES:z[0-9]+]].d }, [[MASK]]/z, [x2, [[PTRS]].d, sxtw #3]
+; VBITS_GE_2048-NEXT: st1d { [[RES]].d }, [[PG0]], [x0]
+; VBITS_GE_2048-NEXT: ret
+  %cvals = load <32 x double>, <32 x double>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr double, double* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x double> %cvals, zeroinitializer
+  %vals = call <32 x double> @llvm.masked.gather.v32f64(<32 x double*> %ptrs, i32 8, <32 x i1> %mask, <32 x double> undef)
+  store <32 x double> %vals, <32 x double>* %a
+  ret void
+}
+
 define void @masked_gather_32b_scaled_zext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; CHECK-LABEL: masked_gather_32b_scaled_zext:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32

diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
index 60f65d190d188..ec0f4bdf025c6 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -839,9 +839,8 @@ define void @masked_scatter_v32f64(<32 x double>* %a, <32 x double*>* %b) #0 {
 
 ; The above tests test the types, the below tests check that the addressing
 ; modes still function
-
-define void @masked_scatter_32b_scaled_sext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
-; CHECK-LABEL: masked_scatter_32b_scaled_sext:
+define void @masked_scatter_32b_scaled_sext_f16(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
+; CHECK-LABEL: masked_scatter_32b_scaled_sext_f16:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32
 ; VBITS_GE_2048-NEXT: ld1h { [[VALS:z[0-9]+]].h }, [[PG0]]/z, [x0]
 ; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
@@ -862,6 +861,41 @@ define void @masked_scatter_32b_scaled_sext(<32 x half>* %a, <32 x i32>* %b, hal
   ret void
 }
 
+define void @masked_scatter_32b_scaled_sext_f32(<32 x float>* %a, <32 x i32>* %b, float* %base) #0 {
+; CHECK-LABEL: masked_scatter_32b_scaled_sext_f32:
+; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[VALS:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].s, [[PG]]/z, [[VALS]].s, #0.0
+; VBITS_GE_2048-NEXT: st1w { [[VALS]].s }, [[MASK]], [x2, [[PTRS]].s, sxtw #2]
+; VBITS_GE_2048-NEXT: ret
+  %vals = load <32 x float>, <32 x float>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr float, float* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x float> %vals, zeroinitializer
+  call void @llvm.masked.scatter.v32f32(<32 x float> %vals, <32 x float*> %ptrs, i32 8, <32 x i1> %mask)
+  ret void
+}
+
+define void @masked_scatter_32b_scaled_sext_f64(<32 x double>* %a, <32 x i32>* %b, double* %base) #0 {
+; CHECK-LABEL: masked_scatter_32b_scaled_sext_f64:
+; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].d, vl32
+; VBITS_GE_2048-NEXT: ld1d { [[VALS:z[0-9]+]].d }, [[PG0]]/z, [x0]
+; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG1]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].d, [[PG0]]/z, [[VALS]].d, #0.0
+; VBITS_GE_2048-NEXT: st1d { [[VALS]].d }, [[MASK]], [x2, [[PTRS]].d, sxtw #3]
+; VBITS_GE_2048-NEXT: ret
+  %vals = load <32 x double>, <32 x double>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr double, double* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x double> %vals, zeroinitializer
+  call void @llvm.masked.scatter.v32f64(<32 x double> %vals, <32 x double*> %ptrs, i32 8, <32 x i1> %mask)
+  ret void
+}
+
 define void @masked_scatter_32b_scaled_zext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; CHECK-LABEL: masked_scatter_32b_scaled_zext:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32

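The scatter side mirrors the gather case. As before, a minimal LLVM IR sketch
of the triggering input, distilled from the masked_scatter_32b_scaled_sext_f64
test above, with an illustrative function name and the #0 attribute set assumed
to be as defined in the test file:

  ; 32-bit offsets, sign-extended to 64 bits and scaled for a 64-bit element type
  define void @scatter_sketch(<32 x double> %vals, <32 x i32>* %b, double* %base, <32 x i1> %mask) #0 {
    %idxs = load <32 x i32>, <32 x i32>* %b
    %ext = sext <32 x i32> %idxs to <32 x i64>
    %ptrs = getelementptr double, double* %base, <32 x i64> %ext
    call void @llvm.masked.scatter.v32f64(<32 x double> %vals, <32 x double*> %ptrs, i32 8, <32 x i1> %mask)
    ret void
  }

  declare void @llvm.masked.scatter.v32f64(<32 x double>, <32 x double*>, i32, <32 x i1>)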