[llvm] 4db11c1 - [AArch64] Lower scalable i1 vector add reduction to cntp (#99031)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 22 02:14:31 PDT 2024


Author: Max Beck-Jones
Date: 2024-07-22T10:14:28+01:00
New Revision: 4db11c1f6cd6cd12b51a3220a54697b90e2e8821

URL: https://github.com/llvm/llvm-project/commit/4db11c1f6cd6cd12b51a3220a54697b90e2e8821
DIFF: https://github.com/llvm/llvm-project/commit/4db11c1f6cd6cd12b51a3220a54697b90e2e8821.diff

LOG: [AArch64] Lower scalable i1 vector add reduction to cntp (#99031)

Doing an add reduction on a vector of i1 elements is the same as
counting the number of set elements, so such a reduction can be lowered
to a cntp instruction. This saves a number of instructions compared with
performing a UADDV. This patch only handles straightforward cases (i.e.,
when vectors are not split).

Added: 
    llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bf205b1706a6c..c11855da3fae0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27640,6 +27640,20 @@ SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
     VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
   }
 
+  // Lower VECREDUCE_ADD of nxv2i1-nxv16i1 to CNTP rather than UADDV.
+  if (ScalarOp.getOpcode() == ISD::VECREDUCE_ADD &&
+      VecOp.getOpcode() == ISD::ZERO_EXTEND) {
+    SDValue BoolVec = VecOp.getOperand(0);
+    if (BoolVec.getValueType().getVectorElementType() == MVT::i1) {
+      // CNTP(BoolVec & BoolVec) <=> CNTP(BoolVec & PTRUE)
+      SDValue CntpOp = DAG.getNode(
+          ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
+          DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+          BoolVec, BoolVec);
+      return DAG.getAnyExtOrTrunc(CntpOp, DL, ScalarOp.getValueType());
+    }
+  }
+
   // UADDV always returns an i64 result.
   EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
                                                    SrcVT.getVectorElementType();

diff  --git a/llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll b/llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll
new file mode 100644
index 0000000000000..a748cf732e090
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-i1-add-reduce.ll
@@ -0,0 +1,132 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define i8 @uaddv_zexti8_nxv16i1(<vscale x 16 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv16i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.b
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 16 x i1> %v to <vscale x 16 x i8>
+  %4 = tail call i8 @llvm.vector.reduce.add.nxv16i8(<vscale x 16 x i8> %3)
+  ret i8 %4
+}
+
+define i8 @uaddv_zexti8_nxv8i1(<vscale x 8 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv8i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.h
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 8 x i1> %v to <vscale x 8 x i8>
+  %4 = tail call i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8> %3)
+  ret i8 %4
+}
+
+define i16 @uaddv_zexti16_nxv8i1(<vscale x 8 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti16_nxv8i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.h
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 8 x i1> %v to <vscale x 8 x i16>
+  %4 = tail call i16 @llvm.vector.reduce.add.nxv8i16(<vscale x 8 x i16> %3)
+  ret i16 %4
+}
+
+define i8 @uaddv_zexti8_nxv4i1(<vscale x 4 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv4i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.s
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i8>
+  %4 = tail call i8 @llvm.vector.reduce.add.nxv4i8(<vscale x 4 x i8> %3)
+  ret i8 %4
+}
+
+define i16 @uaddv_zexti16_nxv4i1(<vscale x 4 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti16_nxv4i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.s
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i16>
+  %4 = tail call i16 @llvm.vector.reduce.add.nxv4i16(<vscale x 4 x i16> %3)
+  ret i16 %4
+}
+
+define i32 @uaddv_zexti32_nxv4i1(<vscale x 4 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti32_nxv4i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.s
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+  %4 = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %3)
+  ret i32 %4
+}
+
+define i8 @uaddv_zexti8_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti8_nxv2i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.d
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i8>
+  %4 = tail call i8 @llvm.vector.reduce.add.nxv2i8(<vscale x 2 x i8> %3)
+  ret i8 %4
+}
+
+define i16 @uaddv_zexti16_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti16_nxv2i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.d
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i16>
+  %4 = tail call i16 @llvm.vector.reduce.add.nxv2i16(<vscale x 2 x i16> %3)
+  ret i16 %4
+}
+
+define i32 @uaddv_zexti32_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti32_nxv2i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.d
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+  %4 = tail call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> %3)
+  ret i32 %4
+}
+
+define i64 @uaddv_zexti64_nxv2i1(<vscale x 2 x i1> %v) {
+; CHECK-LABEL: uaddv_zexti64_nxv2i1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntp x0, p0, p0.d
+; CHECK-NEXT:    ret
+entry:
+  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i64>
+  %4 = tail call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> %3)
+  ret i64 %4
+}
+
+declare i8 @llvm.vector.reduce.add.nxv16i8(<vscale x 16 x i8>)
+declare i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8>)
+declare i16 @llvm.vector.reduce.add.nxv8i16(<vscale x 8 x i16>)
+declare i8 @llvm.vector.reduce.add.nxv4i8(<vscale x 4 x i8>)
+declare i16 @llvm.vector.reduce.add.nxv4i16(<vscale x 4 x i16>)
+declare i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32>)
+declare i8 @llvm.vector.reduce.add.nxv2i8(<vscale x 2 x i8>)
+declare i16 @llvm.vector.reduce.add.nxv2i16(<vscale x 2 x i16>)
+declare i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32>)
+declare i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64>)


        


More information about the llvm-commits mailing list