[llvm] [AArch64] Optimise test of the LSB of a paired whileCC instruction (PR #81141)
Momchil Velikov via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 12 07:20:38 PDT 2024
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/81141
>From 2f5e80c5d218c10a480b0c816883e8d93d359b31 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 11 Jun 2024 14:07:17 +0100
Subject: [PATCH 1/3] [AArch64] Refactor redundant PTEST optimisations (NFC)
Change-Id: I63ff6f4a7f90cd584508cbaa8bba8a39a8ca3f56
---
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 171 ++++++++++---------
llvm/lib/Target/AArch64/AArch64InstrInfo.h | 3 +
2 files changed, 96 insertions(+), 78 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index a5135b78bded9..82061607f2e53 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1355,48 +1355,52 @@ static bool areCFlagsAccessedBetweenInstrs(
return false;
}
-/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
-/// operation which could set the flags in an identical manner
-bool AArch64InstrInfo::optimizePTestInstr(
- MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
- const MachineRegisterInfo *MRI) const {
- auto *Mask = MRI->getUniqueVRegDef(MaskReg);
- auto *Pred = MRI->getUniqueVRegDef(PredReg);
- auto NewOp = Pred->getOpcode();
- bool OpChanged = false;
-
+std::optional<unsigned>
+AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+ MachineInstr *Pred,
+ const MachineRegisterInfo *MRI) const {
unsigned MaskOpcode = Mask->getOpcode();
unsigned PredOpcode = Pred->getOpcode();
bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
bool PredIsWhileLike = isWhileOpcode(PredOpcode);
- if (isPTrueOpcode(MaskOpcode) && (PredIsPTestLike || PredIsWhileLike) &&
- getElementSizeForOpcode(MaskOpcode) ==
- getElementSizeForOpcode(PredOpcode) &&
- Mask->getOperand(1).getImm() == 31) {
+ if (PredIsWhileLike) {
+ // For PTEST(PG, PG), PTEST is redundant when PG is the result of a WHILEcc
+ // instruction and the condition is "any" since WHILcc does an implicit
+ // PTEST(ALL, PG) check and PG is always a subset of ALL.
+ if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+ return PredOpcode;
+
// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
// redundant since WHILE performs an implicit PTEST with an all active
- // mask. Must be an all active predicate of matching element size.
+ // mask.
+ if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+ getElementSizeForOpcode(MaskOpcode) ==
+ getElementSizeForOpcode(PredOpcode))
+ return PredOpcode;
+
+ return {};
+ }
+
+ if (PredIsPTestLike) {
+ // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
+ // instruction that sets the flags as PTEST would and the condition is
+ // "any" since PG is always a subset of the governing predicate of the
+ // ptest-like instruction.
+ if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+ return PredOpcode;
// For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
- // PTEST_LIKE instruction uses the same all active mask and the element
- // size matches. If the PTEST has a condition of any then it is always
- // redundant.
- if (PredIsPTestLike) {
+ // the element size matches and either the PTEST_LIKE instruction uses
+ // the same all active mask or the condition is "any".
+ if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+ getElementSizeForOpcode(MaskOpcode) ==
+ getElementSizeForOpcode(PredOpcode)) {
auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
- if (Mask != PTestLikeMask && PTest->getOpcode() != AArch64::PTEST_PP_ANY)
- return false;
+ if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+ return PredOpcode;
}
- // Fallthough to simply remove the PTEST.
- } else if ((Mask == Pred) && (PredIsPTestLike || PredIsWhileLike) &&
- PTest->getOpcode() == AArch64::PTEST_PP_ANY) {
- // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
- // instruction that sets the flags as PTEST would. This is only valid when
- // the condition is any.
-
- // Fallthough to simply remove the PTEST.
- } else if (PredIsPTestLike) {
// For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
// flags are set based on the same mask 'PG', but PTEST_LIKE must operate
// on 8-bit predicates like the PTEST. Otherwise, for instructions like
@@ -1421,55 +1425,66 @@ bool AArch64InstrInfo::optimizePTestInstr(
// identical regardless of element size.
auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
- if ((Mask != PTestLikeMask) ||
- (PredElementSize != AArch64::ElementSizeB &&
- PTest->getOpcode() != AArch64::PTEST_PP_ANY))
- return false;
+ if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
+ PTest->getOpcode() == AArch64::PTEST_PP_ANY))
+ return PredOpcode;
- // Fallthough to simply remove the PTEST.
- } else {
- // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
- // opcode so the PTEST becomes redundant.
- switch (PredOpcode) {
- case AArch64::AND_PPzPP:
- case AArch64::BIC_PPzPP:
- case AArch64::EOR_PPzPP:
- case AArch64::NAND_PPzPP:
- case AArch64::NOR_PPzPP:
- case AArch64::ORN_PPzPP:
- case AArch64::ORR_PPzPP:
- case AArch64::BRKA_PPzP:
- case AArch64::BRKPA_PPzPP:
- case AArch64::BRKB_PPzP:
- case AArch64::BRKPB_PPzPP:
- case AArch64::RDFFR_PPz: {
- // Check to see if our mask is the same. If not the resulting flag bits
- // may be different and we can't remove the ptest.
- auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
- if (Mask != PredMask)
- return false;
- break;
- }
- case AArch64::BRKN_PPzP: {
- // BRKN uses an all active implicit mask to set flags unlike the other
- // flag-setting instructions.
- // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
- if ((MaskOpcode != AArch64::PTRUE_B) ||
- (Mask->getOperand(1).getImm() != 31))
- return false;
- break;
- }
- case AArch64::PTRUE_B:
- // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
- break;
- default:
- // Bail out if we don't recognize the input
- return false;
- }
+ return {};
+ }
- NewOp = convertToFlagSettingOpc(PredOpcode);
- OpChanged = true;
+ // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
+ // opcode so the PTEST becomes redundant.
+ switch (PredOpcode) {
+ case AArch64::AND_PPzPP:
+ case AArch64::BIC_PPzPP:
+ case AArch64::EOR_PPzPP:
+ case AArch64::NAND_PPzPP:
+ case AArch64::NOR_PPzPP:
+ case AArch64::ORN_PPzPP:
+ case AArch64::ORR_PPzPP:
+ case AArch64::BRKA_PPzP:
+ case AArch64::BRKPA_PPzPP:
+ case AArch64::BRKB_PPzP:
+ case AArch64::BRKPB_PPzPP:
+ case AArch64::RDFFR_PPz: {
+ // Check to see if our mask is the same. If not the resulting flag bits
+ // may be different and we can't remove the ptest.
+ auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+ if (Mask != PredMask)
+ return {};
+ break;
}
+ case AArch64::BRKN_PPzP: {
+ // BRKN uses an all active implicit mask to set flags unlike the other
+ // flag-setting instructions.
+ // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
+ if ((MaskOpcode != AArch64::PTRUE_B) ||
+ (Mask->getOperand(1).getImm() != 31))
+ return {};
+ break;
+ }
+ case AArch64::PTRUE_B:
+ // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
+ break;
+ default:
+ // Bail out if we don't recognize the input
+ return {};
+ }
+
+ return convertToFlagSettingOpc(PredOpcode);
+}
+
+/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
+/// operation which could set the flags in an identical manner
+bool AArch64InstrInfo::optimizePTestInstr(
+ MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
+ const MachineRegisterInfo *MRI) const {
+ auto *Mask = MRI->getUniqueVRegDef(MaskReg);
+ auto *Pred = MRI->getUniqueVRegDef(PredReg);
+ unsigned PredOpcode = Pred->getOpcode();
+ auto NewOp = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+ if (!NewOp)
+ return false;
const TargetRegisterInfo *TRI = &getRegisterInfo();
@@ -1482,9 +1497,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
// as they are prior to PTEST. Sometimes this requires the tested PTEST
// operand to be replaced with an equivalent instruction that also sets the
// flags.
- Pred->setDesc(get(NewOp));
PTest->eraseFromParent();
- if (OpChanged) {
+ if (*NewOp != PredOpcode) {
+ Pred->setDesc(get(*NewOp));
bool succeeded = UpdateOperandRegClass(*Pred);
(void)succeeded;
assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index f434799c3982b..792e0c3063b10 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -572,6 +572,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
unsigned PredReg,
const MachineRegisterInfo *MRI) const;
+ std::optional<unsigned>
+ canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+ MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
};
struct UsedNZCV {
>From 3a79e1d9c1a00f08e3661722571720723a01dd74 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 11 Jun 2024 15:51:15 +0100
Subject: [PATCH 2/3] [AArch64] Precommit testcase for optimised test of the
LSB of a paired whileCC instruction
Change-Id: I5058e24c631ede0a04399b39e5096f898fa8f792
---
llvm/test/CodeGen/AArch64/opt-while-test.ll | 107 ++++++++++++++++++++
1 file changed, 107 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/opt-while-test.ll
diff --git a/llvm/test/CodeGen/AArch64/opt-while-test.ll b/llvm/test/CodeGen/AArch64/opt-while-test.ll
new file mode 100644
index 0000000000000..3a43f817d6154
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/opt-while-test.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefix=CHECK-SVE2p1
+target triple = "aarch64-linux"
+
+define void @f_while(i32 %i, i32 %n) #0 {
+; CHECK-LABEL: f_while:
+; CHECK: // %bb.0: // %E
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: whilelo p0.b, w0, w1
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: mov z0.h, p0/z, #1 // =0x1
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: tbz w8, #0, .LBB0_2
+; CHECK-NEXT: // %bb.1: // %A
+; CHECK-NEXT: bl g0
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+; CHECK-NEXT: .LBB0_2: // %B
+; CHECK-NEXT: bl g1
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+;
+; CHECK-SVE2p1-LABEL: f_while:
+; CHECK-SVE2p1: // %bb.0: // %E
+; CHECK-SVE2p1-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT: whilelo p0.b, w0, w1
+; CHECK-SVE2p1-NEXT: punpklo p0.h, p0.b
+; CHECK-SVE2p1-NEXT: mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE2p1-NEXT: fmov w8, s0
+; CHECK-SVE2p1-NEXT: tbz w8, #0, .LBB0_2
+; CHECK-SVE2p1-NEXT: // %bb.1: // %A
+; CHECK-SVE2p1-NEXT: bl g0
+; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT: ret
+; CHECK-SVE2p1-NEXT: .LBB0_2: // %B
+; CHECK-SVE2p1-NEXT: bl g1
+; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT: ret
+E:
+ %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n)
+ %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0)
+ %elt = extractelement <vscale x 8 x i1> %mask, i64 0
+ br i1 %elt, label %A, label %B
+A:
+ call void @g0()
+ ret void
+B:
+ call void @g1()
+ ret void
+}
+
+define void @f_while_x2(i32 %i, i32 %n) #0 {
+; CHECK-LABEL: f_while_x2:
+; CHECK: // %bb.0: // %E
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: whilelo p1.b, w0, w1
+; CHECK-NEXT: punpkhi p0.h, p1.b
+; CHECK-NEXT: punpklo p1.h, p1.b
+; CHECK-NEXT: mov z0.h, p1/z, #1 // =0x1
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: tbz w8, #0, .LBB1_2
+; CHECK-NEXT: // %bb.1: // %A
+; CHECK-NEXT: bl g0
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+; CHECK-NEXT: .LBB1_2: // %B
+; CHECK-NEXT: bl g1
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+;
+; CHECK-SVE2p1-LABEL: f_while_x2:
+; CHECK-SVE2p1: // %bb.0: // %E
+; CHECK-SVE2p1-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT: mov w8, w1
+; CHECK-SVE2p1-NEXT: mov w9, w0
+; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT: mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE2p1-NEXT: mov p0.b, p1.b
+; CHECK-SVE2p1-NEXT: fmov w8, s0
+; CHECK-SVE2p1-NEXT: tbz w8, #0, .LBB1_2
+; CHECK-SVE2p1-NEXT: // %bb.1: // %A
+; CHECK-SVE2p1-NEXT: bl g0
+; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT: ret
+; CHECK-SVE2p1-NEXT: .LBB1_2: // %B
+; CHECK-SVE2p1-NEXT: bl g1
+; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT: ret
+E:
+ %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n)
+ %mask.hi = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 8)
+ %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0)
+ %elt = extractelement <vscale x 8 x i1> %mask, i64 0
+ br i1 %elt, label %A, label %B
+A:
+ call void @g0(<vscale x 8 x i1> %mask.hi)
+ ret void
+B:
+ call void @g1(<vscale x 8 x i1> %mask.hi)
+ ret void
+}
+
+declare void @g0(...)
+declare void @g1(...)
+
+attributes #0 = { nounwind vscale_range(1,16) "target-cpu"="neoverse-v1" }
>From b8c345999b455db4cf41a257821f265325e2fe02 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 12 Jun 2024 10:25:45 +0100
Subject: [PATCH 3/3] [AArch64] Optimise test of the LSB of a paired whileCC
instruction
Change-Id: Iefc0eb7e4b90715ae08c154dde5bda1091f9de07
---
.../Target/AArch64/AArch64ISelLowering.cpp | 63 ++++++++++++++-----
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 1 +
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 52 ++++++++++-----
llvm/lib/Target/AArch64/AArch64InstrInfo.h | 3 +-
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 9 +--
llvm/lib/Target/AArch64/SVEInstrFormats.td | 17 +++--
llvm/test/CodeGen/AArch64/opt-while-test.ll | 26 +++-----
7 files changed, 109 insertions(+), 62 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c4f819f5fcdd2..1b296dc62fad0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2723,6 +2723,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::INSR)
MAKE_CASE(AArch64ISD::PTEST)
MAKE_CASE(AArch64ISD::PTEST_ANY)
+ MAKE_CASE(AArch64ISD::PTEST_FIRST)
MAKE_CASE(AArch64ISD::PTRUE)
MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
@@ -18656,21 +18657,41 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
AArch64CC::CondCode Cond);
static bool isPredicateCCSettingOp(SDValue N) {
- if ((N.getOpcode() == ISD::SETCC) ||
- (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
- // get_active_lane_mask is lowered to a whilelo instruction.
- N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
+ if (N.getOpcode() == ISD::SETCC)
return true;
- return false;
+ if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ isNullConstant(N.getOperand(1)))
+ N = N.getOperand(0);
+
+ if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+ return false;
+
+ switch (N.getConstantOperandVal(0)) {
+ default:
+ return false;
+ case Intrinsic::aarch64_sve_whilege_x2:
+ case Intrinsic::aarch64_sve_whilegt_x2:
+ case Intrinsic::aarch64_sve_whilehi_x2:
+ case Intrinsic::aarch64_sve_whilehs_x2:
+ case Intrinsic::aarch64_sve_whilele_x2:
+ case Intrinsic::aarch64_sve_whilelo_x2:
+ case Intrinsic::aarch64_sve_whilels_x2:
+ case Intrinsic::aarch64_sve_whilelt_x2:
+ if (N.getResNo() != 0)
+ return false;
+ [[fallthrough]];
+ case Intrinsic::aarch64_sve_whilege:
+ case Intrinsic::aarch64_sve_whilegt:
+ case Intrinsic::aarch64_sve_whilehi:
+ case Intrinsic::aarch64_sve_whilehs:
+ case Intrinsic::aarch64_sve_whilele:
+ case Intrinsic::aarch64_sve_whilelo:
+ case Intrinsic::aarch64_sve_whilels:
+ case Intrinsic::aarch64_sve_whilelt:
+ case Intrinsic::get_active_lane_mask:
+ return true;
+ }
}
// Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -20589,9 +20610,19 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
}
// Set condition code (CC) flags.
- SDValue Test = DAG.getNode(
- Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
- DL, MVT::Other, Pg, Op);
+ AArch64ISD::NodeType NT;
+ switch (Cond) {
+ default:
+ NT = AArch64ISD::PTEST;
+ break;
+ case AArch64CC::ANY_ACTIVE:
+ NT = AArch64ISD::PTEST_ANY;
+ break;
+ case AArch64CC::FIRST_ACTIVE:
+ NT = AArch64ISD::PTEST_FIRST;
+ break;
+ }
+ SDValue Test = DAG.getNode(NT, DL, MVT::Other, Pg, Op);
// Convert CC to integer based on requested condition.
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 48a4ea91c2782..3de6536ee4461 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -356,6 +356,7 @@ enum NodeType : unsigned {
INSR,
PTEST,
PTEST_ANY,
+ PTEST_FIRST,
PTRUE,
CTTZ_ELTS,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 82061607f2e53..d3d53c004f68d 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1184,6 +1184,7 @@ bool AArch64InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
break;
case AArch64::PTEST_PP:
case AArch64::PTEST_PP_ANY:
+ case AArch64::PTEST_PP_FIRST:
SrcReg = MI.getOperand(0).getReg();
SrcReg2 = MI.getOperand(1).getReg();
// Not sure about the mask and value for now...
@@ -1355,12 +1356,25 @@ static bool areCFlagsAccessedBetweenInstrs(
return false;
}
-std::optional<unsigned>
+std::optional<std::pair<unsigned, MachineInstr *>>
AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
MachineInstr *Pred,
const MachineRegisterInfo *MRI) const {
unsigned MaskOpcode = Mask->getOpcode();
unsigned PredOpcode = Pred->getOpcode();
+
+ // Handle a COPY from the LSB of the results of paired WHILEcc instruction.
+ if ((PredOpcode == TargetOpcode::COPY &&
+ Pred->getOperand(1).getSubReg() == AArch64::psub0) ||
+ // Handle unpack of the LSB of the result of a WHILEcc instruction.
+ PredOpcode == AArch64::PUNPKLO_PP) {
+ MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+ if (MI && isWhileOpcode(MI->getOpcode())) {
+ Pred = MI;
+ PredOpcode = MI->getOpcode();
+ }
+ }
+
bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
bool PredIsWhileLike = isWhileOpcode(PredOpcode);
@@ -1369,15 +1383,16 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
// instruction and the condition is "any" since WHILcc does an implicit
// PTEST(ALL, PG) check and PG is always a subset of ALL.
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
- return PredOpcode;
+ return std::make_pair(PredOpcode, Pred);
- // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
- // redundant since WHILE performs an implicit PTEST with an all active
- // mask.
+ // For PTEST(PTRUE_ALL, WHILE), since WHILE performs an implicit PTEST
+ // with an all active mask, the PTEST is redundant if either the element
+ // size matches or the PTEST condition is "first".
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
- getElementSizeForOpcode(MaskOpcode) ==
- getElementSizeForOpcode(PredOpcode))
- return PredOpcode;
+ (PTest->getOpcode() == AArch64::PTEST_PP_FIRST ||
+ getElementSizeForOpcode(MaskOpcode) ==
+ getElementSizeForOpcode(PredOpcode)))
+ return std::make_pair(PredOpcode, Pred);
return {};
}
@@ -1388,7 +1403,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
// "any" since PG is always a subset of the governing predicate of the
// ptest-like instruction.
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
- return PredOpcode;
+ return std::make_pair(PredOpcode, Pred);
// For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
// the element size matches and either the PTEST_LIKE instruction uses
@@ -1398,7 +1413,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
getElementSizeForOpcode(PredOpcode)) {
auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
- return PredOpcode;
+ return std::make_pair(PredOpcode, Pred);
}
// For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
@@ -1427,7 +1442,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
PTest->getOpcode() == AArch64::PTEST_PP_ANY))
- return PredOpcode;
+ return std::make_pair(PredOpcode, Pred);
return {};
}
@@ -1471,7 +1486,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
return {};
}
- return convertToFlagSettingOpc(PredOpcode);
+ return std::make_pair(convertToFlagSettingOpc(PredOpcode), Pred);
}
/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
@@ -1481,10 +1496,12 @@ bool AArch64InstrInfo::optimizePTestInstr(
const MachineRegisterInfo *MRI) const {
auto *Mask = MRI->getUniqueVRegDef(MaskReg);
auto *Pred = MRI->getUniqueVRegDef(PredReg);
+ unsigned NewOp;
unsigned PredOpcode = Pred->getOpcode();
- auto NewOp = canRemovePTestInstr(PTest, Mask, Pred, MRI);
- if (!NewOp)
+ auto canRemove = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+ if (!canRemove)
return false;
+ std::tie(NewOp, Pred) = *canRemove;
const TargetRegisterInfo *TRI = &getRegisterInfo();
@@ -1498,8 +1515,8 @@ bool AArch64InstrInfo::optimizePTestInstr(
// operand to be replaced with an equivalent instruction that also sets the
// flags.
PTest->eraseFromParent();
- if (*NewOp != PredOpcode) {
- Pred->setDesc(get(*NewOp));
+ if (NewOp != PredOpcode) {
+ Pred->setDesc(get(NewOp));
bool succeeded = UpdateOperandRegClass(*Pred);
(void)succeeded;
assert(succeeded && "Operands have incompatible register classes!");
@@ -1560,7 +1577,8 @@ bool AArch64InstrInfo::optimizeCompareInstr(
}
if (CmpInstr.getOpcode() == AArch64::PTEST_PP ||
- CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY)
+ CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY ||
+ CmpInstr.getOpcode() == AArch64::PTEST_PP_FIRST)
return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI);
if (SrcReg2 != 0)
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 792e0c3063b10..d722f433a150b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -572,7 +572,8 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
unsigned PredReg,
const MachineRegisterInfo *MRI) const;
- std::optional<unsigned>
+
+ std::optional<std::pair<unsigned, MachineInstr *>>
canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
};
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index bd5de628d8529..3cee3e92fae08 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -373,9 +373,10 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
(AArch64fadda_p_node (SVEAllActive), node:$op2,
(vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
-def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
-def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
-def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
+def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
+def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
+def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
+def AArch64ptest_first : SDNode<"AArch64ISD::PTEST_FIRST", SDT_AArch64PTest>;
def SDT_AArch64DUP_PRED : SDTypeProfile<1, 3,
[SDTCisVec<0>, SDTCisSameAs<0, 3>, SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0, 1>]>;
@@ -948,7 +949,7 @@ let Predicates = [HasSVEorSME] in {
defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>;
defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
- defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any>;
+ defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any, AArch64ptest_first>;
defm PFALSE : sve_int_pfalse<0b000000, "pfalse">;
defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index fc7d3cdda4acd..1c3528bed08c4 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -784,13 +784,16 @@ class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
}
multiclass sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op,
- SDPatternOperator op_any> {
+ SDPatternOperator op_any, SDPatternOperator op_first> {
def NAME : sve_int_ptest<opc, asm, op>;
let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in {
def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
[(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
+ def _FIRST : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
+ [(op_first (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
+ PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
}
}
@@ -9669,7 +9672,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
// SVE integer compare scalar count and limit (predicate pair)
class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
- RegisterOperand ppr_ty>
+ RegisterOperand ppr_ty, ElementSizeEnum EltSz>
: I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
mnemonic, "\t$Pd, $Rn, $Rm",
"", []>, Sched<[]> {
@@ -9687,16 +9690,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
let Inst{3-1} = Pd;
let Inst{0} = opc{0};
+ let ElementSize = EltSz;
let Defs = [NZCV];
let hasSideEffects = 0;
+ let isWhile = 1;
}
multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
}
diff --git a/llvm/test/CodeGen/AArch64/opt-while-test.ll b/llvm/test/CodeGen/AArch64/opt-while-test.ll
index 3a43f817d6154..a022f4d8c9e23 100644
--- a/llvm/test/CodeGen/AArch64/opt-while-test.ll
+++ b/llvm/test/CodeGen/AArch64/opt-while-test.ll
@@ -8,10 +8,7 @@ define void @f_while(i32 %i, i32 %n) #0 {
; CHECK: // %bb.0: // %E
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: whilelo p0.b, w0, w1
-; CHECK-NEXT: punpklo p0.h, p0.b
-; CHECK-NEXT: mov z0.h, p0/z, #1 // =0x1
-; CHECK-NEXT: fmov w8, s0
-; CHECK-NEXT: tbz w8, #0, .LBB0_2
+; CHECK-NEXT: b.pl .LBB0_2
; CHECK-NEXT: // %bb.1: // %A
; CHECK-NEXT: bl g0
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -25,10 +22,7 @@ define void @f_while(i32 %i, i32 %n) #0 {
; CHECK-SVE2p1: // %bb.0: // %E
; CHECK-SVE2p1-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-SVE2p1-NEXT: whilelo p0.b, w0, w1
-; CHECK-SVE2p1-NEXT: punpklo p0.h, p0.b
-; CHECK-SVE2p1-NEXT: mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE2p1-NEXT: fmov w8, s0
-; CHECK-SVE2p1-NEXT: tbz w8, #0, .LBB0_2
+; CHECK-SVE2p1-NEXT: b.pl .LBB0_2
; CHECK-SVE2p1-NEXT: // %bb.1: // %A
; CHECK-SVE2p1-NEXT: bl g0
; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -54,12 +48,9 @@ define void @f_while_x2(i32 %i, i32 %n) #0 {
; CHECK-LABEL: f_while_x2:
; CHECK: // %bb.0: // %E
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT: whilelo p1.b, w0, w1
-; CHECK-NEXT: punpkhi p0.h, p1.b
-; CHECK-NEXT: punpklo p1.h, p1.b
-; CHECK-NEXT: mov z0.h, p1/z, #1 // =0x1
-; CHECK-NEXT: fmov w8, s0
-; CHECK-NEXT: tbz w8, #0, .LBB1_2
+; CHECK-NEXT: whilelo p0.b, w0, w1
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: b.pl .LBB1_2
; CHECK-NEXT: // %bb.1: // %A
; CHECK-NEXT: bl g0
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -75,15 +66,14 @@ define void @f_while_x2(i32 %i, i32 %n) #0 {
; CHECK-SVE2p1-NEXT: mov w8, w1
; CHECK-SVE2p1-NEXT: mov w9, w0
; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8
-; CHECK-SVE2p1-NEXT: mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE2p1-NEXT: mov p0.b, p1.b
-; CHECK-SVE2p1-NEXT: fmov w8, s0
-; CHECK-SVE2p1-NEXT: tbz w8, #0, .LBB1_2
+; CHECK-SVE2p1-NEXT: b.pl .LBB1_2
; CHECK-SVE2p1-NEXT: // %bb.1: // %A
+; CHECK-SVE2p1-NEXT: mov p0.b, p1.b
; CHECK-SVE2p1-NEXT: bl g0
; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-SVE2p1-NEXT: ret
; CHECK-SVE2p1-NEXT: .LBB1_2: // %B
+; CHECK-SVE2p1-NEXT: mov p0.b, p1.b
; CHECK-SVE2p1-NEXT: bl g1
; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-SVE2p1-NEXT: ret
More information about the llvm-commits
mailing list