[llvm] [AArch64] Lower scalable i1 vector add reduction to cntp (PR #100118)
Max Beck-Jones via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 23 06:08:45 PDT 2024
https://github.com/DevM-uk created https://github.com/llvm/llvm-project/pull/100118
Performing an add reduction on a zero-extended vector of i1 elements is equivalent to counting the number of set (true) elements, so such a reduction can be lowered to a single CNTP instruction. This saves several instructions compared with performing a UADDV. This patch only handles straightforward cases (i.e. when the vector does not need to be split).
>From 2a8f5bbd621bfdbba9d69404c6961b54d4cc8ca4 Mon Sep 17 00:00:00 2001
From: Max Beck-Jones <max.beck-jones at arm.com>
Date: Tue, 16 Jul 2024 10:46:30 +0000
Subject: [PATCH 1/4] [AArch64] Lower scalable i1 vector add reduction to cntp
Performing an add reduction on a zero-extended vector of i1 elements is equivalent to counting the number of set (true) elements, so such a reduction can be lowered to a single CNTP instruction. This saves several instructions compared with performing a UADDV.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 15 ++
.../test/CodeGen/AArch64/sve-i1-add-reduce.ll | 132 ++++++++++++++++++
2 files changed, 147 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e0c3cc5eddb82..f8c10c2ae82f6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27469,6 +27469,21 @@ SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
}
+ // Lower VECREDUCE_ADD of nxv2i1-nxv16i1 to CNTP rather than UADDV.
+ if (ScalarOp.getOpcode() == ISD::VECREDUCE_ADD &&
+ VecOp.getOpcode() == ISD::ZERO_EXTEND) {
+ SDValue Vec = VecOp.getOperand(0);
+ EVT VecVT = Vec.getValueType();
+ if (VecVT.getVectorElementType() == MVT::i1) {
+ // CNTP(Vec & Vec) <=> CNTP(Vec & PTRUE)
+ SDValue CntpOp = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Vec,
+ Vec);
+ return DAG.getAnyExtOrTrunc(CntpOp, DL, ScalarOp.getValueType());
+ }
+ }
+
// UADDV always returns an i64 result.
EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
SrcVT.getVectorElementType();
diff --git a/llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll b/llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll
new file mode 100644
index 0000000000000..a748cf732e090
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll
@@ -0,0 +1,132 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define i8 @uaddv_zexti8_nxv16i1(<vscale x 16 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv16i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.b
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 16 x i1> %v to <vscale x 16 x i8>
+ %4 = tail call i8 @llvm.vector.reduce.add.nxv16i8(<vscale x 16 x i8> %3)
+ ret i8 %4
+}
+
+define i8 @uaddv_zexti8_nxv8i1(<vscale x 8 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv8i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.h
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 8 x i1> %v to <vscale x 8 x i8>
+ %4 = tail call i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8> %3)
+ ret i8 %4
+}
+
+define i16 @uaddv_zexti16_nxv8i1(<vscale x 8 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti16_nxv8i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.h
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 8 x i1> %v to <vscale x 8 x i16>
+ %4 = tail call i16 @llvm.vector.reduce.add.nxv8i16(<vscale x 8 x i16> %3)
+ ret i16 %4
+}
+
+define i8 @uaddv_zexti8_nxv4i1(<vscale x 4 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv4i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.s
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i8>
+ %4 = tail call i8 @llvm.vector.reduce.add.nxv4i8(<vscale x 4 x i8> %3)
+ ret i8 %4
+}
+
+define i16 @uaddv_zexti16_nxv4i1(<vscale x 4 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti16_nxv4i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.s
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i16>
+ %4 = tail call i16 @llvm.vector.reduce.add.nxv4i16(<vscale x 4 x i16> %3)
+ ret i16 %4
+}
+
+define i32 @uaddv_zexti32_nxv4i1(<vscale x 4 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti32_nxv4i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.s
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+ %4 = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %3)
+ ret i32 %4
+}
+
+define i8 @uaddv_zexti8_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv2i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.d
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i8>
+ %4 = tail call i8 @llvm.vector.reduce.add.nxv2i8(<vscale x 2 x i8> %3)
+ ret i8 %4
+}
+
+define i16 @uaddv_zexti16_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti16_nxv2i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.d
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i16>
+ %4 = tail call i16 @llvm.vector.reduce.add.nxv2i16(<vscale x 2 x i16> %3)
+ ret i16 %4
+}
+
+define i32 @uaddv_zexti32_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti32_nxv2i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.d
+; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+ %4 = tail call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> %3)
+ ret i32 %4
+}
+
+define i64 @uaddv_zexti64_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti64_nxv2i1:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: cntp x0, p0, p0.d
+; CHECK-NEXT: ret
+entry:
+ %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i64>
+ %4 = tail call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> %3)
+ ret i64 %4
+}
+
+declare i8 @llvm.vector.reduce.add.nxv16i8(<vscale x 16 x i8>)
+declare i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8>)
+declare i16 @llvm.vector.reduce.add.nxv8i16(<vscale x 8 x i16>)
+declare i8 @llvm.vector.reduce.add.nxv4i8(<vscale x 4 x i8>)
+declare i16 @llvm.vector.reduce.add.nxv4i16(<vscale x 4 x i16>)
+declare i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32>)
+declare i8 @llvm.vector.reduce.add.nxv2i8(<vscale x 2 x i8>)
+declare i16 @llvm.vector.reduce.add.nxv2i16(<vscale x 2 x i16>)
+declare i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32>)
+declare i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64>)
>From 0130bf6a39ed733ead628de473868b02969f9bb5 Mon Sep 17 00:00:00 2001
From: Max Beck-Jones <max.beck-jones at arm.com>
Date: Thu, 18 Jul 2024 09:56:14 +0000
Subject: [PATCH 2/4] fixup: Update vector name
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f8c10c2ae82f6..f2d57200c89d6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27472,14 +27472,14 @@ SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
// Lower VECREDUCE_ADD of nxv2i1-nxv16i1 to CNTP rather than UADDV.
if (ScalarOp.getOpcode() == ISD::VECREDUCE_ADD &&
VecOp.getOpcode() == ISD::ZERO_EXTEND) {
- SDValue Vec = VecOp.getOperand(0);
- EVT VecVT = Vec.getValueType();
+ SDValue BoolVec = VecOp.getOperand(0);
+ EVT VecVT = BoolVec.getValueType();
if (VecVT.getVectorElementType() == MVT::i1) {
- // CNTP(Vec & Vec) <=> CNTP(Vec & PTRUE)
+ // CNTP(BoolVec & BoolVec) <=> CNTP(BoolVec & PTRUE)
SDValue CntpOp = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
- DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Vec,
- Vec);
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), BoolVec,
+ BoolVec);
return DAG.getAnyExtOrTrunc(CntpOp, DL, ScalarOp.getValueType());
}
}
>From 09025adec525441898484a0828be3674bc26c760 Mon Sep 17 00:00:00 2001
From: Max Beck-Jones <max.beck-jones at arm.com>
Date: Fri, 19 Jul 2024 10:19:51 +0000
Subject: [PATCH 3/4] fixup: Inline VecVT
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2d57200c89d6..e8bf2f8c214f6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27473,8 +27473,7 @@ SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
if (ScalarOp.getOpcode() == ISD::VECREDUCE_ADD &&
VecOp.getOpcode() == ISD::ZERO_EXTEND) {
SDValue BoolVec = VecOp.getOperand(0);
- EVT VecVT = BoolVec.getValueType();
- if (VecVT.getVectorElementType() == MVT::i1) {
+ if (BoolVec.getValueType().getVectorElementType() == MVT::i1) {
// CNTP(BoolVec & BoolVec) <=> CNTP(BoolVec & PTRUE)
SDValue CntpOp = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
>From 31be5831479c5e3a37737099a1542d04e21f6601 Mon Sep 17 00:00:00 2001
From: Max Beck-Jones <max.beck-jones at arm.com>
Date: Fri, 19 Jul 2024 11:02:43 +0000
Subject: [PATCH 4/4] fixup: Update formatting
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e8bf2f8c214f6..afb208e78a059 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27477,8 +27477,8 @@ SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
// CNTP(BoolVec & BoolVec) <=> CNTP(BoolVec & PTRUE)
SDValue CntpOp = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
- DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), BoolVec,
- BoolVec);
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+ BoolVec, BoolVec);
return DAG.getAnyExtOrTrunc(CntpOp, DL, ScalarOp.getValueType());
}
}
More information about the llvm-commits mailing list