[llvm] cc68e05 - [SVE][CodeGen] Improve codegen for some zero-extends of masked loads

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 17 01:19:33 PDT 2023


Author: David Sherwood
Date: 2023-07-17T08:19:27Z
New Revision: cc68e05bd2cc477f7286b5b566558e990f0e5075

URL: https://github.com/llvm/llvm-project/commit/cc68e05bd2cc477f7286b5b566558e990f0e5075
DIFF: https://github.com/llvm/llvm-project/commit/cc68e05bd2cc477f7286b5b566558e990f0e5075.diff

LOG: [SVE][CodeGen] Improve codegen for some zero-extends of masked loads

When doing a masked load of an illegal unpacked type and then
zero-extending to some illegal wider types, we sometimes end up
with pointless 'and' instructions that try to zero bits that we
already know are zero. This patch fixes that by adding more
cases to performSVEAndCombine.

Differential Revision: https://reviews.llvm.org/D155281

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2a2953c359984a..8059592623713f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17053,14 +17053,28 @@ static SDValue performSVEAndCombine(SDNode *N,
 
     uint64_t ExtVal = C->getZExtValue();
 
+    auto MaskAndTypeMatch = [ExtVal](EVT VT) -> bool {
+      return ((ExtVal == 0xFF && VT == MVT::i8) ||
+              (ExtVal == 0xFFFF && VT == MVT::i16) ||
+              (ExtVal == 0xFFFFFFFF && VT == MVT::i32));
+    };
+
     // If the mask is fully covered by the unpack, we don't need to push
     // a new AND onto the operand
     EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
-    if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
-        (ExtVal == 0xFFFF && EltTy == MVT::i16) ||
-        (ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
+    if (MaskAndTypeMatch(EltTy))
       return Src;
 
+    // If this is 'and (uunpklo/hi (extload MemTy -> ExtTy)), mask', then check
+    // to see if the mask is all-ones of size MemTy.
+    auto MaskedLoadOp = dyn_cast<MaskedLoadSDNode>(UnpkOp);
+    if (MaskedLoadOp && (MaskedLoadOp->getExtensionType() == ISD::ZEXTLOAD ||
+                         MaskedLoadOp->getExtensionType() == ISD::EXTLOAD)) {
+      EVT EltTy = MaskedLoadOp->getMemoryVT().getVectorElementType();
+      if (MaskAndTypeMatch(EltTy))
+        return Src;
+    }
+
     // Truncate to prevent a DUP with an over wide constant
     APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
 

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll
index 55bd3833f611c0..c495e983818f78 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll
@@ -52,10 +52,9 @@ define <vscale x 16 x i32> @masked_ld1b_i8_zext_i32(<vscale x 16 x i8> *%base, <
 define <vscale x 8 x i32> @masked_ld1b_nxv8i8_zext_i32(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1b_nxv8i8_zext_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1b { z0.h }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.s, z0.h
-; CHECK-NEXT:    and z0.h, z0.h, #0xff
-; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ld1b { z1.h }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8.p0(ptr %a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> poison)
   %res = zext <vscale x 8 x i8> %wide.masked.load to <vscale x 8 x i32>
@@ -125,10 +124,9 @@ define <vscale x 16 x i64> @masked_ld1b_i8_zext(<vscale x 16 x i8> *%base, <vsca
 define <vscale x 4 x i64> @masked_ld1b_nxv4i8_zext_i64(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1b_nxv4i8_zext_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1b { z0.s }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.d, z0.s
-; CHECK-NEXT:    and z0.s, z0.s, #0xff
-; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    ld1b { z1.s }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8.p0(ptr %a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> poison)
   %res = zext <vscale x 4 x i8> %wide.masked.load to <vscale x 4 x i64>
@@ -186,10 +184,9 @@ define <vscale x 8 x i64> @masked_ld1h_i16_zext(<vscale x 8 x i16> *%base, <vsca
 define <vscale x 4 x i64> @masked_ld1h_nxv4i16_zext(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1h_nxv4i16_zext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.d, z0.s
-; CHECK-NEXT:    and z0.s, z0.s, #0xffff
-; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16.p0(ptr %a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> poison)
   %res = zext <vscale x 4 x i16> %wide.masked.load to <vscale x 4 x i64>


        


More information about the llvm-commits mailing list