[llvm] [AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) (PR #161384)
Kerry McLaughlin via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 1 07:23:11 PDT 2025
https://github.com/kmclaughlin-arm updated https://github.com/llvm/llvm-project/pull/161384
>From a3167b5d438413715690df382fc0c3a0a3ecbef8 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Wed, 24 Sep 2025 09:02:09 +0000
Subject: [PATCH 1/3] [AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) ->
PTEST_FIRST(PTRUE, A)
When the input to a ptest_first is a vector concat and the mask is all
active, performPTestFirstCombine returns a ptest_first using the
first operand of the concat, looking through any reinterpret casts
added by getPTest.
This allows optimizePTestInstr to later remove the ptest when the
first operand is a flag-setting instruction such as whilelo.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 40 ++++++++++++++++++-
.../AArch64/get-active-lane-mask-extract.ll | 20 +---------
2 files changed, 40 insertions(+), 20 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 45f52352d45fd..1f641cd8c14fc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
}
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
- AArch64CC::CondCode Cond);
+ AArch64CC::CondCode Cond, bool EmitCSel = true);
static bool isPredicateCCSettingOp(SDValue N) {
if ((N.getOpcode() == ISD::SETCC) ||
@@ -20495,6 +20495,7 @@ static SDValue
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+
if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
return Res;
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
@@ -22535,7 +22536,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
}
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
- AArch64CC::CondCode Cond) {
+ AArch64CC::CondCode Cond, bool EmitCSel) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDLoc DL(Op);
@@ -22568,6 +22569,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
// Set condition code (CC) flags.
SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op);
+ if (!EmitCSel)
+ return Test;
// Convert CC to integer based on requested condition.
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
@@ -27519,6 +27522,37 @@ static SDValue performMULLCombine(SDNode *N,
return SDValue();
}
+static SDValue performPTestFirstCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ SDLoc DL(N);
+ auto Mask = N->getOperand(0);
+ auto Pred = N->getOperand(1);
+
+ if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ Mask = Mask->getOperand(0);
+
+ if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ Pred = Pred->getOperand(0);
+
+ if (Pred->getValueType(0).getVectorElementType() != MVT::i1 ||
+ !isAllActivePredicate(DAG, Mask))
+ return SDValue();
+
+ if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
+ Pred = Pred->getOperand(0);
+ SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL,
+ DAG.getAllOnesConstant(DL, MVT::i64));
+ return getPTest(DAG, N->getValueType(0), Mask, Pred,
+ AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false);
+ }
+
+ return SDValue();
+}
+
static SDValue
performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
@@ -27875,6 +27909,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case AArch64ISD::UMULL:
case AArch64ISD::PMULL:
return performMULLCombine(N, DCI, DAG);
+ case AArch64ISD::PTEST_FIRST:
+ return performPTestFirstCombine(N, DCI, DAG);
case ISD::INTRINSIC_VOID:
case ISD::INTRINSIC_W_CHAIN:
switch (N->getConstantOperandVal(1)) {
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index b89f55188b0f2..e2c861b40e706 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
-; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
-; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
-; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
@@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
-; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
-; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
-; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
-; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
>From b99f48c44bbf8aeb126ebbafaadc0e09eef4769c Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Wed, 1 Oct 2025 10:22:23 +0000
Subject: [PATCH 2/3] - Add isLane1KnownActive helper - Extend
canRemovePTestInstr for PTEST_PP_FIRST - Remove getVectorElementType() !=
MVT::i1 check for Mask
---
.../Target/AArch64/AArch64ISelLowering.cpp | 35 +++++++++++--------
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 4 +++
2 files changed, 25 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1f641cd8c14fc..20b8e75040512 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
}
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
- AArch64CC::CondCode Cond, bool EmitCSel = true);
+ AArch64CC::CondCode Cond);
static bool isPredicateCCSettingOp(SDValue N) {
if ((N.getOpcode() == ISD::SETCC) ||
@@ -20495,7 +20495,6 @@ static SDValue
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
-
if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
return Res;
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
@@ -22536,7 +22535,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
}
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
- AArch64CC::CondCode Cond, bool EmitCSel) {
+ AArch64CC::CondCode Cond) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDLoc DL(Op);
@@ -22569,8 +22568,6 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
// Set condition code (CC) flags.
SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op);
- if (!EmitCSel)
- return Test;
// Convert CC to integer based on requested condition.
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
@@ -27237,6 +27234,21 @@ static bool isLanes1toNKnownZero(SDValue Op) {
}
}
+// Return true if the vector operation can guarantee that the first lane of its
+// result is active.
+static bool isLane1KnownActive(SDValue Op) {
+ switch (Op.getOpcode()) {
+ default:
+ return false;
+ case AArch64ISD::REINTERPRET_CAST:
+ return isLane1KnownActive(Op->getOperand(0));
+ case ISD::SPLAT_VECTOR:
+ return isOneConstant(Op.getOperand(0));
+ case AArch64ISD::PTRUE:
+ return Op.getConstantOperandVal(0) == AArch64SVEPredPattern::all;
+ };
+}
+
static SDValue removeRedundantInsertVectorElt(SDNode *N) {
assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!");
SDValue InsertVec = N->getOperand(0);
@@ -27532,22 +27544,17 @@ static SDValue performPTestFirstCombine(SDNode *N,
auto Mask = N->getOperand(0);
auto Pred = N->getOperand(1);
- if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST)
- Mask = Mask->getOperand(0);
-
if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
Pred = Pred->getOperand(0);
- if (Pred->getValueType(0).getVectorElementType() != MVT::i1 ||
- !isAllActivePredicate(DAG, Mask))
+ if (!isLane1KnownActive(Mask))
return SDValue();
if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
Pred = Pred->getOperand(0);
- SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL,
- DAG.getAllOnesConstant(DL, MVT::i64));
- return getPTest(DAG, N->getValueType(0), Mask, Pred,
- AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false);
+ Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pred);
+ return DAG.getNode(AArch64ISD::PTEST_FIRST, DL, N->getValueType(0), Mask,
+ Pred);
}
return SDValue();
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 5a51c812732e6..ef056002f085c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1503,6 +1503,10 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
getElementSizeForOpcode(PredOpcode))
return PredOpcode;
+ if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST &&
+ isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31)
+ return PredOpcode;
+
return {};
}
>From c5402afee5b5578a44ec7714e234efae05410ed2 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Wed, 1 Oct 2025 14:18:04 +0000
Subject: [PATCH 3/3] - Rename isLane1KnownActive -> isLane0KnownActive
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 20b8e75040512..cc4baa474f99b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27236,12 +27236,12 @@ static bool isLanes1toNKnownZero(SDValue Op) {
// Return true if the vector operation can guarantee that the first lane of its
// result is active.
-static bool isLane1KnownActive(SDValue Op) {
+static bool isLane0KnownActive(SDValue Op) {
switch (Op.getOpcode()) {
default:
return false;
case AArch64ISD::REINTERPRET_CAST:
- return isLane1KnownActive(Op->getOperand(0));
+ return isLane0KnownActive(Op->getOperand(0));
case ISD::SPLAT_VECTOR:
return isOneConstant(Op.getOperand(0));
case AArch64ISD::PTRUE:
@@ -27547,7 +27547,7 @@ static SDValue performPTestFirstCombine(SDNode *N,
if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
Pred = Pred->getOperand(0);
- if (!isLane1KnownActive(Mask))
+ if (!isLane0KnownActive(Mask))
return SDValue();
if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
More information about the llvm-commits
mailing list