[llvm] 20e0e80 - [AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) (#161384)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 2 02:13:11 PDT 2025


Author: Kerry McLaughlin
Date: 2025-10-02T10:13:06+01:00
New Revision: 20e0e80a540223194e06d5e593634f65e1ee0de8

URL: https://github.com/llvm/llvm-project/commit/20e0e80a540223194e06d5e593634f65e1ee0de8
DIFF: https://github.com/llvm/llvm-project/commit/20e0e80a540223194e06d5e593634f65e1ee0de8.diff

LOG: [AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) (#161384)

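When the predicate operand of ptest_first is a vector concat and the first
lane of the mask is known to be active (e.g. an all-active ptrue),
performPTestFirstCombine returns a ptest_first of the concat's first
operand instead. Reinterpret casts are looked through on the mask, and a
single reinterpret cast is looked through on the predicate.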

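This allows optimizePTestInstr to remove the ptest later when that first
operand is produced by a flag-setting instruction such as whilelo. To
support this, canRemovePTestInstr now treats a PTEST_FIRST with an
all-active PTRUE mask as redundant after a WHILEcc. WHILEcc already
performs an implicit PTEST with an all-active mask and so sets the N
flag exactly as the PTEST_FIRST would.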

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
    llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 45f52352d45fd..a1f4734f83562 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27234,6 +27234,21 @@ static bool isLanes1toNKnownZero(SDValue Op) {
   }
 }
 
+// Return true if the vector operation can guarantee that the first lane of its
+// result is active.
+static bool isLane0KnownActive(SDValue Op) {
+  switch (Op.getOpcode()) {
+  default:
+    return false;
+  case AArch64ISD::REINTERPRET_CAST:
+    return isLane0KnownActive(Op->getOperand(0));
+  case ISD::SPLAT_VECTOR:
+    return isOneConstant(Op.getOperand(0));
+  case AArch64ISD::PTRUE:
+    return Op.getConstantOperandVal(0) == AArch64SVEPredPattern::all;
+  };
+}
+
 static SDValue removeRedundantInsertVectorElt(SDNode *N) {
   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!");
   SDValue InsertVec = N->getOperand(0);
@@ -27519,6 +27534,32 @@ static SDValue performMULLCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performPTestFirstCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        SelectionDAG &DAG) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDLoc DL(N);
+  auto Mask = N->getOperand(0);
+  auto Pred = N->getOperand(1);
+
+  if (!isLane0KnownActive(Mask))
+    return SDValue();
+
+  if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    Pred = Pred->getOperand(0);
+
+  if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
+    Pred = Pred->getOperand(0);
+    Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pred);
+    return DAG.getNode(AArch64ISD::PTEST_FIRST, DL, N->getValueType(0), Mask,
+                       Pred);
+  }
+
+  return SDValue();
+}
+
 static SDValue
 performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                              SelectionDAG &DAG) {
@@ -27875,6 +27916,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case AArch64ISD::UMULL:
   case AArch64ISD::PMULL:
     return performMULLCombine(N, DCI, DAG);
+  case AArch64ISD::PTEST_FIRST:
+    return performPTestFirstCombine(N, DCI, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (N->getConstantOperandVal(1)) {

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 5a51c812732e6..35b27ea2ec9dd 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1503,6 +1503,13 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
             getElementSizeForOpcode(PredOpcode))
       return PredOpcode;
 
+    // For PTEST_FIRST(PTRUE_ALL, WHILE), the PTEST_FIRST is redundant since
+    // WHILEcc performs an implicit PTEST with an all active mask, setting
+    // the N flag as the PTEST_FIRST would.
+    if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST &&
+        isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31)
+      return PredOpcode;
+
     return {};
   }
 

diff  --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index b89f55188b0f2..e2c861b40e706 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.b
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.b, p0.b, p1.b
-; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
 ; CHECK-SVE2p1-SME2-NEXT:    b use
@@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
 ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
 ; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p2.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p3.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT:    ptest p2, p3.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
 ; CHECK-SVE2p1-SME2-NEXT:    b use
@@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
 ; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
 ; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.s, p3.s }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.h, p2.h, p3.h
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.b, p4.b, p5.b
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.b
-; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB13_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.s, p3.s }, x8, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB13_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret
@@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
 ; CHECK-SVE2p1-SME2-NEXT:    adds x8, x0, x8
 ; CHECK-SVE2p1-SME2-NEXT:    csinv x8, x8, xzr, lo
 ; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.d, p1.d }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.s, p0.s, p1.s
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p5.s, p2.s, p3.s
-; CHECK-SVE2p1-SME2-NEXT:    uzp1 p4.h, p4.h, p5.h
-; CHECK-SVE2p1-SME2-NEXT:    ptrue p5.h
-; CHECK-SVE2p1-SME2-NEXT:    ptest p5, p4.b
 ; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB14_2
 ; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p2.d, p3.d }, x8, x1
 ; CHECK-SVE2p1-SME2-NEXT:    b use
 ; CHECK-SVE2p1-SME2-NEXT:  .LBB14_2: // %if.end
 ; CHECK-SVE2p1-SME2-NEXT:    ret


        


More information about the llvm-commits mailing list