[llvm] [AArch64] Extend custom lowering for SVE types in `@llvm.experimental.vector.compress` (PR #105515)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 21 05:50:16 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Lawrence Benson (lawben)
<details>
<summary>Changes</summary>
This is a follow-up to #101015. We now support `@llvm.experimental.vector.compress` for SVE types that don't map directly to `compact`, i.e., the `<vscale x 8 x ...>` and `<vscale x 16 x ...>` integer types (`nxv8i8`, `nxv8i16`, `nxv16i8`). The same logic is also used for the corresponding NEON vectors (`<8 x i8>`, `<8 x i16>`, `<16 x i8>`).
---
Full diff: https://github.com/llvm/llvm-project/pull/105515.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+65-16)
- (modified) llvm/test/CodeGen/AArch64/sve-vector-compress.ll (+129)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e1d265fdf0d1a8..f4d3fa114ddc3d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,16 +1781,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
- // We can lower types that have <vscale x {2|4}> elements to compact.
+ // Lower these SVE types via `compact`; types with more than 4 elements are split first.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
- MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+ MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+ MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
- MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
+ MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
+ MVT::v8i8, MVT::v8i16, MVT::v16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
// Histcnt is SVE2 only
@@ -6659,10 +6661,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();
- // Only <vscale x {4|2} x {i32|i64}> supported for compact.
- if (MinElmts != 2 && MinElmts != 4)
- return SDValue();
-
// We can use the SVE register containing the NEON vector in its lowest bits.
if (IsFixedLength) {
EVT ScalableVecVT =
@@ -6690,16 +6688,67 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
EVT ContainerVT = getSVEContainerType(VecVT);
EVT CastVT = VecVT.changeVectorElementTypeToInteger();
- // Convert to i32 or i64 for smaller types, as these are the only supported
- // sizes for compact.
- if (ContainerVT != VecVT) {
- Vec = DAG.getBitcast(CastVT, Vec);
- Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
- }
+ // These vector types aren't supported by the `compact` instruction, so
+ // we split and compact them as <vscale x 4 x i32>, store them on the stack,
+ // and then merge them again. In the other cases, emit compact directly.
+ SDValue Compressed;
+ if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8) {
+ SDValue Chain = DAG.getEntryNode();
+ SDValue StackPtr = DAG.CreateStackTemporary(
+ VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+ MachineFunction &MF = DAG.getMachineFunction();
+
+ EVT PartialVecVT =
+ EVT::getVectorVT(*DAG.getContext(), ElmtVT, 4, /*isScalable*/ true);
+ EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
+ SDValue Offset = DAG.getConstant(0, DL, OffsetVT);
+
+ for (unsigned I = 0; I < MinElmts; I += 4) {
+ SDValue PartialVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT,
+ Vec, DAG.getVectorIdxConstant(I, DL));
+ PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);
+
+ SDValue PartialMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1,
+ Mask, DAG.getVectorIdxConstant(I, DL));
+
+ SDValue PartialCompressed = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
+ PartialMask, PartialVec);
+ PartialCompressed =
+ DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);
+
+ SDValue OutPtr = DAG.getNode(
+ ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
+ DAG.getNode(
+ ISD::MUL, DL, OffsetVT, Offset,
+ DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
+ Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
+ MachinePointerInfo::getUnknownStack(MF));
+
+ SDValue PartialOffset = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
+ DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+ PartialMask, PartialMask);
+ Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
+ }
+
+ MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+ MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+ Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+ } else {
+ // Convert to i32 or i64 for smaller types, as these are the only supported
+ // sizes for compact.
+ if (ContainerVT != VecVT) {
+ Vec = DAG.getBitcast(CastVT, Vec);
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+ }
- SDValue Compressed = DAG.getNode(
- ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
- DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+ Compressed = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
+ Vec);
+ }
// compact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
index 84c15e4fbc33c7..fc8cbea0d47156 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -91,6 +91,101 @@ define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <v
ret <vscale x 4 x float> %out
}
+define <vscale x 8 x i8> @test_compress_nxv8i8(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: addpl x9, sp, #4
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: cntp x8, p1, p1.s
+; CHECK-NEXT: compact z1.s, p1, z1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ptrue p1.h
+; CHECK-NEXT: st1b { z1.s }, p0, [sp, #2, mul vl]
+; CHECK-NEXT: st1b { z0.s }, p0, [x9, x8]
+; CHECK-NEXT: ld1b { z0.h }, p1/z, [sp, #1, mul vl]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 8 x i8> @llvm.experimental.vector.compress(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+ ret <vscale x 8 x i8> %out
+}
+
+define <vscale x 8 x i16> @test_compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: cntp x8, p1, p1.s
+; CHECK-NEXT: compact z1.s, p1, z1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ptrue p1.h
+; CHECK-NEXT: st1h { z1.s }, p0, [sp]
+; CHECK-NEXT: st1h { z0.s }, p0, [x9, x8, lsl #1]
+; CHECK-NEXT: ld1h { z0.h }, p1/z, [sp]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+ ret <vscale x 8 x i16> %out
+}
+
+define <vscale x 16 x i8> @test_compress_nxv16i8(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.h, z0.b
+; CHECK-NEXT: punpklo p2.h, p0.b
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: uunpkhi z0.h, z0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: punpklo p3.h, p2.b
+; CHECK-NEXT: punpkhi p2.h, p2.b
+; CHECK-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: cntp x8, p3, p3.s
+; CHECK-NEXT: uunpklo z3.s, z0.h
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: compact z2.s, p3, z2.s
+; CHECK-NEXT: compact z1.s, p2, z1.s
+; CHECK-NEXT: punpklo p3.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: st1b { z2.s }, p1, [sp]
+; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT: compact z1.s, p3, z3.s
+; CHECK-NEXT: incp x8, p2.s
+; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT: incp x8, p3.s
+; CHECK-NEXT: st1b { z0.s }, p1, [x9, x8]
+; CHECK-NEXT: ld1b { z0.b }, p0/z, [sp]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 16 x i8> @llvm.experimental.vector.compress(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+ ret <vscale x 16 x i8> %out
+}
+
define <vscale x 4 x i4> @test_compress_illegal_element_type(<vscale x 4 x i4> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compress_illegal_element_type:
; CHECK: // %bb.0:
@@ -240,6 +335,40 @@ define <2 x i16> @test_compress_v2i16_with_sve(<2 x i16> %vec, <2 x i1> %mask) {
ret <2 x i16> %out
}
+define <8 x i16> @test_compress_v8i16_with_sve(<8 x i16> %vec, <8 x i1> %mask) {
+; CHECK-LABEL: test_compress_v8i16_with_sve:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and z1.h, z1.h, #0x1
+; CHECK-NEXT: cmpne p1.h, p0/z, z1.h, #0
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: punpklo p2.h, p1.b
+; CHECK-NEXT: punpkhi p1.h, p1.b
+; CHECK-NEXT: compact z1.s, p2, z1.s
+; CHECK-NEXT: cntp x8, p2, p2.s
+; CHECK-NEXT: compact z0.s, p1, z0.s
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: st1h { z1.s }, p1, [sp]
+; CHECK-NEXT: st1h { z0.s }, p1, [x9, x8, lsl #1]
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [sp]
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <8 x i16> @llvm.experimental.vector.compress(<8 x i16> %vec, <8 x i1> %mask, <8 x i16> undef)
+ ret <8 x i16> %out
+}
+
define <vscale x 4 x i32> @test_compress_nxv4i32_with_passthru(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
; CHECK-LABEL: test_compress_nxv4i32_with_passthru:
``````````
</details>
https://github.com/llvm/llvm-project/pull/105515
More information about the llvm-commits mailing list