[llvm] [AArch64] Extend custom lowering for SVE types in `@llvm.experimental.vector.compress` (PR #105515)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 21 05:50:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Lawrence Benson (lawben)

<details>
<summary>Changes</summary>

This is a follow-up to #101015. We now support `@llvm.experimental.vector.compress` for SVE types that don't map directly to `compact`, i.e., `<vscale x 8 x ..>` and `<vscale x 16 x ..>`. The same logic is also used for the corresponding NEON vectors.

---
Full diff: https://github.com/llvm/llvm-project/pull/105515.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+65-16) 
- (modified) llvm/test/CodeGen/AArch64/sve-vector-compress.ll (+129) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e1d265fdf0d1a8..f4d3fa114ddc3d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,16 +1781,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                     MVT::v2f32, MVT::v4f32, MVT::v2f64})
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
-    // We can lower types that have <vscale x {2|4}> elements to compact.
+    // We can lower all legal (or smaller) SVE types to `compact`.
     for (auto VT :
          {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
-          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+          MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
       setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
 
     // If we have SVE, we can use SVE logic for legal (or smaller than legal)
     // NEON vectors in the lowest bits of the SVE register.
     for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
-                    MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
+                    MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
+                    MVT::v8i8, MVT::v8i16, MVT::v16i8})
       setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
 
     // Histcnt is SVE2 only
@@ -6659,10 +6661,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
     return SDValue();
 
-  // Only <vscale x {4|2} x {i32|i64}> supported for compact.
-  if (MinElmts != 2 && MinElmts != 4)
-    return SDValue();
-
   // We can use the SVE register containing the NEON vector in its lowest bits.
   if (IsFixedLength) {
     EVT ScalableVecVT =
@@ -6690,16 +6688,67 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   EVT ContainerVT = getSVEContainerType(VecVT);
   EVT CastVT = VecVT.changeVectorElementTypeToInteger();
 
-  // Convert to i32 or i64 for smaller types, as these are the only supported
-  // sizes for compact.
-  if (ContainerVT != VecVT) {
-    Vec = DAG.getBitcast(CastVT, Vec);
-    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
-  }
+  // These vector types aren't supported by the `compact` instruction, so
+  // we split and compact them as <vscale x 4 x i32>, store them on the stack,
+  // and then merge them again. In the other cases, emit compact directly.
+  SDValue Compressed;
+  if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8) {
+    SDValue Chain = DAG.getEntryNode();
+    SDValue StackPtr = DAG.CreateStackTemporary(
+        VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+    MachineFunction &MF = DAG.getMachineFunction();
+
+    EVT PartialVecVT =
+        EVT::getVectorVT(*DAG.getContext(), ElmtVT, 4, /*isScalable*/ true);
+    EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
+    SDValue Offset = DAG.getConstant(0, DL, OffsetVT);
+
+    for (unsigned I = 0; I < MinElmts; I += 4) {
+      SDValue PartialVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT,
+                                       Vec, DAG.getVectorIdxConstant(I, DL));
+      PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);
+
+      SDValue PartialMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1,
+                                        Mask, DAG.getVectorIdxConstant(I, DL));
+
+      SDValue PartialCompressed = DAG.getNode(
+          ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+          DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
+          PartialMask, PartialVec);
+      PartialCompressed =
+          DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);
+
+      SDValue OutPtr = DAG.getNode(
+          ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
+          DAG.getNode(
+              ISD::MUL, DL, OffsetVT, Offset,
+              DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
+      Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
+                           MachinePointerInfo::getUnknownStack(MF));
+
+      SDValue PartialOffset = DAG.getNode(
+          ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
+          DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+          PartialMask, PartialMask);
+      Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
+    }
+
+    MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+        MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+    Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+  } else {
+    // Convert to i32 or i64 for smaller types, as these are the only supported
+    // sizes for compact.
+    if (ContainerVT != VecVT) {
+      Vec = DAG.getBitcast(CastVT, Vec);
+      Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+    }
 
-  SDValue Compressed = DAG.getNode(
-      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
-      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+    Compressed = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+        DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
+        Vec);
+  }
 
   // compact fills with 0s, so if our passthru is all 0s, do nothing here.
   if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
index 84c15e4fbc33c7..fc8cbea0d47156 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -91,6 +91,101 @@ define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <v
     ret <vscale x 4 x float> %out
 }
 
+define <vscale x 8 x i8> @test_compress_nxv8i8(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    addpl x9, sp, #4
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    cntp x8, p1, p1.s
+; CHECK-NEXT:    compact z1.s, p1, z1.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    st1b { z1.s }, p0, [sp, #2, mul vl]
+; CHECK-NEXT:    st1b { z0.s }, p0, [x9, x8]
+; CHECK-NEXT:    ld1b { z0.h }, p1/z, [sp, #1, mul vl]
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <vscale x 8 x i8> @llvm.experimental.vector.compress(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+    ret <vscale x 8 x i8> %out
+}
+
+define <vscale x 8 x i16> @test_compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    cntp x8, p1, p1.s
+; CHECK-NEXT:    compact z1.s, p1, z1.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    st1h { z1.s }, p0, [sp]
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, x8, lsl #1]
+; CHECK-NEXT:    ld1h { z0.h }, p1/z, [sp]
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+    ret <vscale x 8 x i16> %out
+}
+
+define <vscale x 16 x i8> @test_compress_nxv16i8(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    uunpklo z1.h, z0.b
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    uunpkhi z0.h, z0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpklo p3.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    cntp x8, p3, p3.s
+; CHECK-NEXT:    uunpklo z3.s, z0.h
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    compact z2.s, p3, z2.s
+; CHECK-NEXT:    compact z1.s, p2, z1.s
+; CHECK-NEXT:    punpklo p3.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    st1b { z2.s }, p1, [sp]
+; CHECK-NEXT:    st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT:    compact z1.s, p3, z3.s
+; CHECK-NEXT:    incp x8, p2.s
+; CHECK-NEXT:    st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT:    incp x8, p3.s
+; CHECK-NEXT:    st1b { z0.s }, p1, [x9, x8]
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [sp]
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <vscale x 16 x i8> @llvm.experimental.vector.compress(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+    ret <vscale x 16 x i8> %out
+}
+
 define <vscale x 4 x i4> @test_compress_illegal_element_type(<vscale x 4 x i4> %vec, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: test_compress_illegal_element_type:
 ; CHECK:       // %bb.0:
@@ -240,6 +335,40 @@ define <2 x i16> @test_compress_v2i16_with_sve(<2 x i16> %vec, <2 x i1> %mask) {
     ret <2 x i16> %out
 }
 
+define <8 x i16> @test_compress_v8i16_with_sve(<8 x i16> %vec, <8 x i1> %mask) {
+; CHECK-LABEL: test_compress_v8i16_with_sve:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and z1.h, z1.h, #0x1
+; CHECK-NEXT:    cmpne p1.h, p0/z, z1.h, #0
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    punpklo p2.h, p1.b
+; CHECK-NEXT:    punpkhi p1.h, p1.b
+; CHECK-NEXT:    compact z1.s, p2, z1.s
+; CHECK-NEXT:    cntp x8, p2, p2.s
+; CHECK-NEXT:    compact z0.s, p1, z0.s
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    st1h { z1.s }, p1, [sp]
+; CHECK-NEXT:    st1h { z0.s }, p1, [x9, x8, lsl #1]
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [sp]
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+    %out = call <8 x i16> @llvm.experimental.vector.compress(<8 x i16> %vec, <8 x i1> %mask, <8 x i16> undef)
+    ret <8 x i16> %out
+}
+
 
 define <vscale x 4 x i32> @test_compress_nxv4i32_with_passthru(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
 ; CHECK-LABEL: test_compress_nxv4i32_with_passthru:

``````````

</details>


https://github.com/llvm/llvm-project/pull/105515


More information about the llvm-commits mailing list