[llvm] 0d153df - [SVE] Fix selection failure when splitting extended masked loads

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 21 05:05:22 PDT 2021


Author: Kerry McLaughlin
Date: 2021-10-21T13:04:38+01:00
New Revision: 0d153df69e8fe28bdf7e65195d3708f331106088

URL: https://github.com/llvm/llvm-project/commit/0d153df69e8fe28bdf7e65195d3708f331106088
DIFF: https://github.com/llvm/llvm-project/commit/0d153df69e8fe28bdf7e65195d3708f331106088.diff

LOG: [SVE] Fix selection failure when splitting extended masked loads

When splitting a masked load, `GetDependentSplitDestVTs` is used to get the
MemVTs of the low and high parts. If the masked load is extended, the
function may return VTs with different element types, which are then used
to create the low & high masked load instructions and lead to a selection
failure.
This patch changes `GetDependentSplitDestVTs` to ensure that the VTs it
returns always have the same element type.

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D111996

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
    llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0a5dce34cd3ab..b928fd3b78d4a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -10642,14 +10642,14 @@ SelectionDAG::GetDependentSplitDestVTs(const EVT &VT, const EVT &EnvVT,
          "Mixing fixed width and scalable vectors when enveloping a type");
   EVT LoVT, HiVT;
   if (VTNumElts.getKnownMinValue() > EnvNumElts.getKnownMinValue()) {
-    LoVT = EnvVT;
+    LoVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
     HiVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts - EnvNumElts);
     *HiIsEmpty = false;
   } else {
     // Flag that hi type has zero storage size, but return split envelop type
     // (this would be easier if vector types with zero elements were allowed).
     LoVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts);
-    HiVT = EnvVT;
+    HiVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
     *HiIsEmpty = true;
   }
   return std::make_pair(LoVT, HiVT);

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
index f7efa546bb0b6..9a20035d3c798 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -70,9 +70,29 @@ define <vscale x 2 x i64> @masked_sload_passthru(<vscale x 2 x i32> *%a, <vscale
   ret <vscale x 2 x i64> %ext
 }
 
+; Return type requires splitting
+define <vscale x 16 x i32> @masked_sload_nxv16i8(<vscale x 16 x i8>* %a, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv16i8:
+; CHECK:         punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpklo p2.h, p1.b
+; CHECK-NEXT:    punpkhi p1.h, p1.b
+; CHECK-NEXT:    ld1sb { z0.s }, p2/z, [x0]
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    ld1sb { z1.s }, p1/z, [x0, #1, mul vl]
+; CHECK-NEXT:    ld1sb { z2.s }, p2/z, [x0, #2, mul vl]
+; CHECK-NEXT:    ld1sb { z3.s }, p0/z, [x0, #3, mul vl]
+; CHECK-NEXT:    ret
+  %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>* %a, i32 2, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  %ext = sext <vscale x 16 x i8> %load to <vscale x 16 x i32>
+  ret <vscale x 16 x i32> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
 declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
 declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
index 7dbebee3d1eec..79eff4d7c572e 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -76,9 +76,29 @@ define <vscale x 2 x i64> @masked_zload_passthru(<vscale x 2 x i32>* %src, <vsca
   ret <vscale x 2 x i64> %ext
 }
 
+; Return type requires splitting
+define <vscale x 8 x i64> @masked_zload_nxv8i16(<vscale x 8 x i16>* %a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv8i16:
+; CHECK:       punpklo p1.h, p0.b
+; CHECK-NEXT:  punpkhi p0.h, p0.b
+; CHECK-NEXT:  punpklo p2.h, p1.b
+; CHECK-NEXT:  punpkhi p1.h, p1.b
+; CHECK-NEXT:  ld1h { z0.d }, p2/z, [x0]
+; CHECK-NEXT:  punpklo p2.h, p0.b
+; CHECK-NEXT:  punpkhi p0.h, p0.b
+; CHECK-NEXT:  ld1h { z1.d }, p1/z, [x0, #1, mul vl]
+; CHECK-NEXT:  ld1h { z2.d }, p2/z, [x0, #2, mul vl]
+; CHECK-NEXT:  ld1h { z3.d }, p0/z, [x0, #3, mul vl]
+; CHECK-NEXT:  ret
+  %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>* %a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+  %ext = zext <vscale x 8 x i16> %load to <vscale x 8 x i64>
+  ret <vscale x 8 x i64> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
 declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
 declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)


        


More information about the llvm-commits mailing list