[llvm] [AArch64] Extend custom lowering for SVE types in `@llvm.experimental.vector.compress` (PR #105515)
Lawrence Benson via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 21 05:49:40 PDT 2024
https://github.com/lawben created https://github.com/llvm/llvm-project/pull/105515
This is a follow-up to #101015. We now support `@llvm.experimental.vector.compress` for SVE types that don't map directly to `compact`, i.e., `<vscale x 8 x ..>` and `<vscale x 16 x ..>`. For these, we split the vector into `<vscale x 4 x ..>` parts, `compact` each part individually, store the compacted parts back-to-back on the stack (advancing the write offset with `cntp` on each part's mask), and reload the merged result. We also reuse this lowering for the corresponding NEON vectors (`v8i8`, `v8i16`, `v16i8`), which are handled in the lowest bits of an SVE register. A minimal example of a call that now lowers via this path is shown below.
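For reference, this is the shape of call that is now custom-lowered instead of falling back to the generic expansion (adapted from the new tests in this patch; the function name is illustrative):

    define <vscale x 8 x i16> @compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
      ; Passthru is undef in the new tests; a non-all-zeros passthru is merged in after the compact.
      %out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
      ret <vscale x 8 x i16> %out
    }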
From 7f210aeb289915b36b3c3e206f867655c95676ea Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 21 Aug 2024 14:32:39 +0200
Subject: [PATCH 1/3] Add support for SVE vectors with a minimum of 8 or 16 elements
---
.../Target/AArch64/AArch64ISelLowering.cpp | 80 ++++++++++++----
.../CodeGen/AArch64/sve-vector-compress.ll | 95 +++++++++++++++++++
2 files changed, 159 insertions(+), 16 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e1d265fdf0d1a8..7894c9934c7d40 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,10 +1781,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
- // We can lower types that have <vscale x {2|4}> elements to compact.
+ // We can lower all legal (or smaller) SVE types to `compact`.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
- MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+ MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32, MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
// If we have SVE, we can use SVE logic for legal (or smaller than legal)
@@ -6659,12 +6659,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();
- // Only <vscale x {4|2} x {i32|i64}> supported for compact.
- if (MinElmts != 2 && MinElmts != 4)
- return SDValue();
-
// We can use the SVE register containing the NEON vector in its lowest bits.
- if (IsFixedLength) {
+ // TODO: check if ==2|4 is needed
+ if (IsFixedLength && (MinElmts == 2 || MinElmts == 4)) {
EVT ScalableVecVT =
MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
EVT ScalableMaskVT = MVT::getScalableVectorVT(
@@ -6690,16 +6687,67 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
EVT ContainerVT = getSVEContainerType(VecVT);
EVT CastVT = VecVT.changeVectorElementTypeToInteger();
- // Convert to i32 or i64 for smaller types, as these are the only supported
- // sizes for compact.
- if (ContainerVT != VecVT) {
- Vec = DAG.getBitcast(CastVT, Vec);
- Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
- }
+ // These vector types aren't supported by the `compact` instruction, so
+ // we split and compact them as <vscale x 4 x i32>, store them on the stack,
+ // and then merge them again. In the other cases, emit compact directly.
+ SDValue Compressed;
+ if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8) {
+ SDValue Chain = DAG.getEntryNode();
+ SDValue StackPtr = DAG.CreateStackTemporary(
+ VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+ MachineFunction &MF = DAG.getMachineFunction();
+
+ EVT PartialVecVT =
+ EVT::getVectorVT(*DAG.getContext(), ElmtVT, 4, /*isScalable*/ true);
+ EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
+ SDValue Offset = DAG.getConstant(0, DL, OffsetVT);
+
+ for (unsigned I = 0; I < MinElmts; I += 4) {
+ SDValue PartialVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT,
+ Vec, DAG.getVectorIdxConstant(I, DL));
+ PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);
+
+ SDValue PartialMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1,
+ Mask, DAG.getVectorIdxConstant(I, DL));
+
+ SDValue PartialCompressed = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
+ PartialMask, PartialVec);
+ PartialCompressed =
+ DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);
+
+ SDValue OutPtr = DAG.getNode(
+ ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
+ DAG.getNode(
+ ISD::MUL, DL, OffsetVT, Offset,
+ DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
+ Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
+ MachinePointerInfo::getUnknownStack(MF));
+
+ SDValue PartialOffset = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
+ DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+ PartialMask, PartialMask);
+ Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
+ }
+
+ MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+ MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+ Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+ } else {
+ // Convert to i32 or i64 for smaller types, as these are the only supported
+ // sizes for compact.
+ if (ContainerVT != VecVT) {
+ Vec = DAG.getBitcast(CastVT, Vec);
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+ }
- SDValue Compressed = DAG.getNode(
- ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
- DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+ Compressed = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
+ Vec);
+ }
// compact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
index 84c15e4fbc33c7..0501b0fece0c88 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -91,6 +91,101 @@ define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <v
ret <vscale x 4 x float> %out
}
+define <vscale x 8 x i8> @test_compress_nxv8i8(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: addpl x9, sp, #4
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: cntp x8, p1, p1.s
+; CHECK-NEXT: compact z1.s, p1, z1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ptrue p1.h
+; CHECK-NEXT: st1b { z1.s }, p0, [sp, #2, mul vl]
+; CHECK-NEXT: st1b { z0.s }, p0, [x9, x8]
+; CHECK-NEXT: ld1b { z0.h }, p1/z, [sp, #1, mul vl]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 8 x i8> @llvm.experimental.vector.compress(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+ ret <vscale x 8 x i8> %out
+}
+
+define <vscale x 8 x i16> @test_compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: cntp x8, p1, p1.s
+; CHECK-NEXT: compact z1.s, p1, z1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ptrue p1.h
+; CHECK-NEXT: st1h { z1.s }, p0, [sp]
+; CHECK-NEXT: st1h { z0.s }, p0, [x9, x8, lsl #1]
+; CHECK-NEXT: ld1h { z0.h }, p1/z, [sp]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+ ret <vscale x 8 x i16> %out
+}
+
+define <vscale x 16 x i8> @test_compress_nxv16i8(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: test_compress_nxv16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: uunpklo z1.h, z0.b
+; CHECK-NEXT: punpklo p2.h, p0.b
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: uunpkhi z0.h, z0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: punpklo p3.h, p2.b
+; CHECK-NEXT: punpkhi p2.h, p2.b
+; CHECK-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: cntp x8, p3, p3.s
+; CHECK-NEXT: uunpklo z3.s, z0.h
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: compact z2.s, p3, z2.s
+; CHECK-NEXT: compact z1.s, p2, z1.s
+; CHECK-NEXT: punpklo p3.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: st1b { z2.s }, p1, [sp]
+; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT: compact z1.s, p3, z3.s
+; CHECK-NEXT: incp x8, p2.s
+; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
+; CHECK-NEXT: incp x8, p3.s
+; CHECK-NEXT: st1b { z0.s }, p1, [x9, x8]
+; CHECK-NEXT: ld1b { z0.b }, p0/z, [sp]
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <vscale x 16 x i8> @llvm.experimental.vector.compress(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+ ret <vscale x 16 x i8> %out
+}
+
define <vscale x 4 x i4> @test_compress_illegal_element_type(<vscale x 4 x i4> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compress_illegal_element_type:
; CHECK: // %bb.0:
From 9e0fd027b6b959b830e4b067f9e439ecf633005e Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 21 Aug 2024 14:45:28 +0200
Subject: [PATCH 2/3] Extend SVE approach to larger NEON vectors
---
.../Target/AArch64/AArch64ISelLowering.cpp | 6 ++--
.../CodeGen/AArch64/sve-vector-compress.ll | 34 +++++++++++++++++++
2 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7894c9934c7d40..02b0c8798c625f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1790,7 +1790,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
- MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
+ MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
+ MVT::v8i8, MVT::v8i16, MVT::v16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
// Histcnt is SVE2 only
@@ -6660,8 +6661,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
return SDValue();
// We can use the SVE register containing the NEON vector in its lowest bits.
- // TODO: check if ==2|4 is needed
- if (IsFixedLength && (MinElmts == 2 || MinElmts == 4)) {
+ if (IsFixedLength) {
EVT ScalableVecVT =
MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
EVT ScalableMaskVT = MVT::getScalableVectorVT(
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
index 0501b0fece0c88..fc8cbea0d47156 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -335,6 +335,40 @@ define <2 x i16> @test_compress_v2i16_with_sve(<2 x i16> %vec, <2 x i1> %mask) {
ret <2 x i16> %out
}
+define <8 x i16> @test_compress_v8i16_with_sve(<8 x i16> %vec, <8 x i1> %mask) {
+; CHECK-LABEL: test_compress_v8i16_with_sve:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: addvl sp, sp, #-1
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and z1.h, z1.h, #0x1
+; CHECK-NEXT: cmpne p1.h, p0/z, z1.h, #0
+; CHECK-NEXT: uunpklo z1.s, z0.h
+; CHECK-NEXT: uunpkhi z0.s, z0.h
+; CHECK-NEXT: punpklo p2.h, p1.b
+; CHECK-NEXT: punpkhi p1.h, p1.b
+; CHECK-NEXT: compact z1.s, p2, z1.s
+; CHECK-NEXT: cntp x8, p2, p2.s
+; CHECK-NEXT: compact z0.s, p1, z0.s
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: st1h { z1.s }, p1, [sp]
+; CHECK-NEXT: st1h { z0.s }, p1, [x9, x8, lsl #1]
+; CHECK-NEXT: ld1h { z0.h }, p0/z, [sp]
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: addvl sp, sp, #1
+; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ %out = call <8 x i16> @llvm.experimental.vector.compress(<8 x i16> %vec, <8 x i1> %mask, <8 x i16> undef)
+ ret <8 x i16> %out
+}
+
define <vscale x 4 x i32> @test_compress_nxv4i32_with_passthru(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
; CHECK-LABEL: test_compress_nxv4i32_with_passthru:
From 36f7a7ed53f9093edcd8cdefe2e3f7d44d89b4ca Mon Sep 17 00:00:00 2001
From: Lawrence Benson <github at lawben.com>
Date: Wed, 21 Aug 2024 14:46:25 +0200
Subject: [PATCH 3/3] Fix formatting
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 02b0c8798c625f..f4d3fa114ddc3d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1784,7 +1784,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// We can lower all legal (or smaller) SVE types to `compact`.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
- MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32, MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
+ MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+ MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
// If we have SVE, we can use SVE logic for legal (or smaller than legal)