[llvm] [AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) (PR #161384)

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 30 07:31:41 PDT 2025


https://github.com/kmclaughlin-arm created https://github.com/llvm/llvm-project/pull/161384

When the input to ptest_first is a vector concat and the mask is all active,
performPTestFirstCombine returns a ptest_first using the first operand
of the concat, looking through any reinterpret casts.

This allows optimizePTestInstr to later remove the ptest when the first
operand is a flag-setting instruction such as whilelo.

From a3167b5d438413715690df382fc0c3a0a3ecbef8 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Wed, 24 Sep 2025 09:02:09 +0000
Subject: [PATCH] [AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) ->
 PTEST_FIRST(PTRUE, A)

When the input to a ptest_first is a vector concat and the mask is all
active, performPTestFirstCombine returns a ptest_first using the
first operand of the concat, looking through any reinterpret casts
added by getPTest.

This allows optimizePTestInstr to later remove the ptest when the
first operand is a flag-setting instruction such as whilelo.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 40 ++++++++++++++++++-
 .../AArch64/get-active-lane-mask-extract.ll   | 20 +---------
 2 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 45f52352d45fd..1f641cd8c14fc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
 }
 
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
-                        AArch64CC::CondCode Cond);
+                        AArch64CC::CondCode Cond, bool EmitCSel = true);
 
 static bool isPredicateCCSettingOp(SDValue N) {
   if ((N.getOpcode() == ISD::SETCC) ||
@@ -20495,6 +20495,7 @@ static SDValue
 performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64Subtarget *Subtarget) {
   assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+
   if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
     return Res;
   if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
@@ -22535,7 +22536,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
 }
 
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
-                        AArch64CC::CondCode Cond) {
+                        AArch64CC::CondCode Cond, bool EmitCSel) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   SDLoc DL(Op);
@@ -22568,6 +22569,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
 
   // Set condition code (CC) flags.
   SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op);
+  if (!EmitCSel)
+    return Test;
 
   // Convert CC to integer based on requested condition.
   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
@@ -27519,6 +27522,37 @@ static SDValue performMULLCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performPTestFirstCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        SelectionDAG &DAG) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDLoc DL(N);
+  auto Mask = N->getOperand(0);
+  auto Pred = N->getOperand(1);
+
+  if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    Mask = Mask->getOperand(0);
+
+  if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    Pred = Pred->getOperand(0);
+
+  if (Pred->getValueType(0).getVectorElementType() != MVT::i1 ||
+      !isAllActivePredicate(DAG, Mask))
+    return SDValue();
+
+  if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
+    Pred = Pred->getOperand(0);
+    SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL,
+                                      DAG.getAllOnesConstant(DL, MVT::i64));
+    return getPTest(DAG, N->getValueType(0), Mask, Pred,
+                    AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false);
+  }
+
+  return SDValue();
+}
+
 static SDValue
 performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                              SelectionDAG &DAG) {
@@ -27875,6 +27909,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case AArch64ISD::UMULL:
   case AArch64ISD::PMULL:
     return performMULLCombine(N, DCI, DAG);
+  case AArch64ISD::PTEST_FIRST:
+    return performPTestFirstCombine(N, DCI, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (N->getConstantOperandVal(1)) {
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index b89f55188b0f2..e2c861b40e706 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.b
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.b, p0.b, p1.b
-; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
 ; CHECK-SVE2p1-SME2-NEXT:    b use
@@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
 ; CHECK-SVE2p1-SME2-NEXT:    b use
@@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
 ; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
 ; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.s, p3.s }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.h, p2.h, p3.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.b, p4.b, p5.b
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.b
-; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB13_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.s, p3.s }, x8, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB13_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret
@@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
 ; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
 ; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.d, p1.d }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.s, p0.s, p1.s
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.s, p2.s, p3.s
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p4.h, p5.h
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.h
-; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB14_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB14_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret



More information about the llvm-commits mailing list