[llvm] [AArch64] Optimise test of the LSB of a paired whileCC instruction (PR #81141)

Momchil Velikov via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 07:20:38 PDT 2024


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/81141

>From 2f5e80c5d218c10a480b0c816883e8d93d359b31 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 11 Jun 2024 14:07:17 +0100
Subject: [PATCH 1/3] [AArch64] Refactor redundant PTEST optimisations (NFC)

Change-Id: I63ff6f4a7f90cd584508cbaa8bba8a39a8ca3f56
---
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 171 ++++++++++---------
 llvm/lib/Target/AArch64/AArch64InstrInfo.h   |   3 +
 2 files changed, 96 insertions(+), 78 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index a5135b78bded9..82061607f2e53 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1355,48 +1355,52 @@ static bool areCFlagsAccessedBetweenInstrs(
   return false;
 }
 
-/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
-/// operation which could set the flags in an identical manner
-bool AArch64InstrInfo::optimizePTestInstr(
-    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
-    const MachineRegisterInfo *MRI) const {
-  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
-  auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
-  bool OpChanged = false;
-
+std::optional<unsigned>
+AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                                      MachineInstr *Pred,
+                                      const MachineRegisterInfo *MRI) const {
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
-  if (isPTrueOpcode(MaskOpcode) && (PredIsPTestLike || PredIsWhileLike) &&
-      getElementSizeForOpcode(MaskOpcode) ==
-          getElementSizeForOpcode(PredOpcode) &&
-      Mask->getOperand(1).getImm() == 31) {
+  if (PredIsWhileLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of a WHILEcc
+    // instruction and the condition is "any" since WHILcc does an implicit
+    // PTEST(ALL, PG) check and PG is always a subset of ALL.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return PredOpcode;
+
     // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
     // redundant since WHILE performs an implicit PTEST with an all active
-    // mask. Must be an all active predicate of matching element size.
+    // mask.
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode))
+      return PredOpcode;
+
+    return {};
+  }
+
+  if (PredIsPTestLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
+    // instruction that sets the flags as PTEST would and the condition is
+    // "any" since PG is always a subset of the governing predicate of the
+    // ptest-like instruction.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return PredOpcode;
 
     // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
-    // PTEST_LIKE instruction uses the same all active mask and the element
-    // size matches. If the PTEST has a condition of any then it is always
-    // redundant.
-    if (PredIsPTestLike) {
+    // the element size matches and either the PTEST_LIKE instruction uses
+    // the same all active mask or the condition is "any".
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode)) {
       auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PTestLikeMask && PTest->getOpcode() != AArch64::PTEST_PP_ANY)
-        return false;
+      if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+        return PredOpcode;
     }
 
-    // Fallthough to simply remove the PTEST.
-  } else if ((Mask == Pred) && (PredIsPTestLike || PredIsWhileLike) &&
-             PTest->getOpcode() == AArch64::PTEST_PP_ANY) {
-    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
-    // instruction that sets the flags as PTEST would. This is only valid when
-    // the condition is any.
-
-    // Fallthough to simply remove the PTEST.
-  } else if (PredIsPTestLike) {
     // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
     // flags are set based on the same mask 'PG', but PTEST_LIKE must operate
     // on 8-bit predicates like the PTEST.  Otherwise, for instructions like
@@ -1421,55 +1425,66 @@ bool AArch64InstrInfo::optimizePTestInstr(
     // identical regardless of element size.
     auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
     uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
-    if ((Mask != PTestLikeMask) ||
-        (PredElementSize != AArch64::ElementSizeB &&
-         PTest->getOpcode() != AArch64::PTEST_PP_ANY))
-      return false;
+    if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
+                                  PTest->getOpcode() == AArch64::PTEST_PP_ANY))
+      return PredOpcode;
 
-    // Fallthough to simply remove the PTEST.
-  } else {
-    // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
-    // opcode so the PTEST becomes redundant.
-    switch (PredOpcode) {
-    case AArch64::AND_PPzPP:
-    case AArch64::BIC_PPzPP:
-    case AArch64::EOR_PPzPP:
-    case AArch64::NAND_PPzPP:
-    case AArch64::NOR_PPzPP:
-    case AArch64::ORN_PPzPP:
-    case AArch64::ORR_PPzPP:
-    case AArch64::BRKA_PPzP:
-    case AArch64::BRKPA_PPzPP:
-    case AArch64::BRKB_PPzP:
-    case AArch64::BRKPB_PPzPP:
-    case AArch64::RDFFR_PPz: {
-      // Check to see if our mask is the same. If not the resulting flag bits
-      // may be different and we can't remove the ptest.
-      auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PredMask)
-        return false;
-      break;
-    }
-    case AArch64::BRKN_PPzP: {
-      // BRKN uses an all active implicit mask to set flags unlike the other
-      // flag-setting instructions.
-      // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
-      if ((MaskOpcode != AArch64::PTRUE_B) ||
-          (Mask->getOperand(1).getImm() != 31))
-        return false;
-      break;
-    }
-    case AArch64::PTRUE_B:
-      // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
-      break;
-    default:
-      // Bail out if we don't recognize the input
-      return false;
-    }
+    return {};
+  }
 
-    NewOp = convertToFlagSettingOpc(PredOpcode);
-    OpChanged = true;
+  // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
+  // opcode so the PTEST becomes redundant.
+  switch (PredOpcode) {
+  case AArch64::AND_PPzPP:
+  case AArch64::BIC_PPzPP:
+  case AArch64::EOR_PPzPP:
+  case AArch64::NAND_PPzPP:
+  case AArch64::NOR_PPzPP:
+  case AArch64::ORN_PPzPP:
+  case AArch64::ORR_PPzPP:
+  case AArch64::BRKA_PPzP:
+  case AArch64::BRKPA_PPzPP:
+  case AArch64::BRKB_PPzP:
+  case AArch64::BRKPB_PPzPP:
+  case AArch64::RDFFR_PPz: {
+    // Check to see if our mask is the same. If not the resulting flag bits
+    // may be different and we can't remove the ptest.
+    auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (Mask != PredMask)
+      return {};
+    break;
   }
+  case AArch64::BRKN_PPzP: {
+    // BRKN uses an all active implicit mask to set flags unlike the other
+    // flag-setting instructions.
+    // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
+    if ((MaskOpcode != AArch64::PTRUE_B) ||
+        (Mask->getOperand(1).getImm() != 31))
+      return {};
+    break;
+  }
+  case AArch64::PTRUE_B:
+    // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
+    break;
+  default:
+    // Bail out if we don't recognize the input
+    return {};
+  }
+
+  return convertToFlagSettingOpc(PredOpcode);
+}
+
+/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
+/// operation which could set the flags in an identical manner
+bool AArch64InstrInfo::optimizePTestInstr(
+    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
+    const MachineRegisterInfo *MRI) const {
+  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
+  auto *Pred = MRI->getUniqueVRegDef(PredReg);
+  unsigned PredOpcode = Pred->getOpcode();
+  auto NewOp = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+  if (!NewOp)
+    return false;
 
   const TargetRegisterInfo *TRI = &getRegisterInfo();
 
@@ -1482,9 +1497,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
-  if (OpChanged) {
+  if (*NewOp != PredOpcode) {
+    Pred->setDesc(get(*NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index f434799c3982b..792e0c3063b10 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -572,6 +572,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
                           unsigned PredReg,
                           const MachineRegisterInfo *MRI) const;
+  std::optional<unsigned>
+  canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                      MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
 };
 
 struct UsedNZCV {

>From 3a79e1d9c1a00f08e3661722571720723a01dd74 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 11 Jun 2024 15:51:15 +0100
Subject: [PATCH 2/3] [AArch64] Precommit testcase for optimised test of the
 LSB of a paired whileCC instruction

Change-Id: I5058e24c631ede0a04399b39e5096f898fa8f792
---
 llvm/test/CodeGen/AArch64/opt-while-test.ll | 107 ++++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/opt-while-test.ll

diff --git a/llvm/test/CodeGen/AArch64/opt-while-test.ll b/llvm/test/CodeGen/AArch64/opt-while-test.ll
new file mode 100644
index 0000000000000..3a43f817d6154
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/opt-while-test.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc                < %s | FileCheck %s
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefix=CHECK-SVE2p1
+target triple = "aarch64-linux"
+
+define void @f_while(i32 %i, i32 %n) #0 {
+; CHECK-LABEL: f_while:
+; CHECK:       // %bb.0: // %E
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    whilelo p0.b, w0, w1
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-NEXT:  // %bb.1: // %A
+; CHECK-NEXT:    bl g0
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB0_2: // %B
+; CHECK-NEXT:    bl g1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f_while:
+; CHECK-SVE2p1:       // %bb.0: // %E
+; CHECK-SVE2p1-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE2p1-NEXT:    punpklo p0.h, p0.b
+; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE2p1-NEXT:    fmov w8, s0
+; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
+; CHECK-SVE2p1-NEXT:    bl g0
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+; CHECK-SVE2p1-NEXT:  .LBB0_2: // %B
+; CHECK-SVE2p1-NEXT:    bl g1
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+E:
+  %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n)
+  %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0)
+  %elt = extractelement <vscale x 8 x i1> %mask, i64 0
+  br i1 %elt, label %A, label %B
+A:
+  call void @g0()
+  ret void
+B:
+  call void @g1()
+  ret void
+}
+
+define void @f_while_x2(i32 %i, i32 %n) #0 {
+; CHECK-LABEL: f_while_x2:
+; CHECK:       // %bb.0: // %E
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    whilelo p1.b, w0, w1
+; CHECK-NEXT:    punpkhi p0.h, p1.b
+; CHECK-NEXT:    punpklo p1.h, p1.b
+; CHECK-NEXT:    mov z0.h, p1/z, #1 // =0x1
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-NEXT:  // %bb.1: // %A
+; CHECK-NEXT:    bl g0
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB1_2: // %B
+; CHECK-NEXT:    bl g1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f_while_x2:
+; CHECK-SVE2p1:       // %bb.0: // %E
+; CHECK-SVE2p1-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
+; CHECK-SVE2p1-NEXT:    fmov w8, s0
+; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
+; CHECK-SVE2p1-NEXT:    bl g0
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+; CHECK-SVE2p1-NEXT:  .LBB1_2: // %B
+; CHECK-SVE2p1-NEXT:    bl g1
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+E:
+  %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n)
+  %mask.hi = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 8)
+  %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0)
+  %elt = extractelement <vscale x 8 x i1> %mask, i64 0
+  br i1 %elt, label %A, label %B
+A:
+  call void @g0(<vscale x 8 x i1> %mask.hi)
+  ret void
+B:
+  call void @g1(<vscale x 8 x i1> %mask.hi)
+  ret void
+}
+
+declare void @g0(...)
+declare void @g1(...)
+
+attributes #0 = { nounwind vscale_range(1,16) "target-cpu"="neoverse-v1" }

>From b8c345999b455db4cf41a257821f265325e2fe02 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 12 Jun 2024 10:25:45 +0100
Subject: [PATCH 3/3] [AArch64] Optimise test of the LSB of a paired whileCC
 instruction

Change-Id: Iefc0eb7e4b90715ae08c154dde5bda1091f9de07
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 63 ++++++++++++++-----
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  1 +
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  | 52 ++++++++++-----
 llvm/lib/Target/AArch64/AArch64InstrInfo.h    |  3 +-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  9 +--
 llvm/lib/Target/AArch64/SVEInstrFormats.td    | 17 +++--
 llvm/test/CodeGen/AArch64/opt-while-test.ll   | 26 +++-----
 7 files changed, 109 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c4f819f5fcdd2..1b296dc62fad0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2723,6 +2723,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::INSR)
     MAKE_CASE(AArch64ISD::PTEST)
     MAKE_CASE(AArch64ISD::PTEST_ANY)
+    MAKE_CASE(AArch64ISD::PTEST_FIRST)
     MAKE_CASE(AArch64ISD::PTRUE)
     MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
     MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
@@ -18656,21 +18657,41 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
 static bool isPredicateCCSettingOp(SDValue N) {
-  if ((N.getOpcode() == ISD::SETCC) ||
-      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
-        // get_active_lane_mask is lowered to a whilelo instruction.
-        N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
+  if (N.getOpcode() == ISD::SETCC)
     return true;
 
-  return false;
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      isNullConstant(N.getOperand(1)))
+    N = N.getOperand(0);
+
+  if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return false;
+
+  switch (N.getConstantOperandVal(0)) {
+  default:
+    return false;
+  case Intrinsic::aarch64_sve_whilege_x2:
+  case Intrinsic::aarch64_sve_whilegt_x2:
+  case Intrinsic::aarch64_sve_whilehi_x2:
+  case Intrinsic::aarch64_sve_whilehs_x2:
+  case Intrinsic::aarch64_sve_whilele_x2:
+  case Intrinsic::aarch64_sve_whilelo_x2:
+  case Intrinsic::aarch64_sve_whilels_x2:
+  case Intrinsic::aarch64_sve_whilelt_x2:
+    if (N.getResNo() != 0)
+      return false;
+    [[fallthrough]];
+  case Intrinsic::aarch64_sve_whilege:
+  case Intrinsic::aarch64_sve_whilegt:
+  case Intrinsic::aarch64_sve_whilehi:
+  case Intrinsic::aarch64_sve_whilehs:
+  case Intrinsic::aarch64_sve_whilele:
+  case Intrinsic::aarch64_sve_whilelo:
+  case Intrinsic::aarch64_sve_whilels:
+  case Intrinsic::aarch64_sve_whilelt:
+  case Intrinsic::get_active_lane_mask:
+    return true;
+  }
 }
 
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -20589,9 +20610,19 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
   }
 
   // Set condition code (CC) flags.
-  SDValue Test = DAG.getNode(
-      Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
-      DL, MVT::Other, Pg, Op);
+  AArch64ISD::NodeType NT;
+  switch (Cond) {
+  default:
+    NT = AArch64ISD::PTEST;
+    break;
+  case AArch64CC::ANY_ACTIVE:
+    NT = AArch64ISD::PTEST_ANY;
+    break;
+  case AArch64CC::FIRST_ACTIVE:
+    NT = AArch64ISD::PTEST_FIRST;
+    break;
+  }
+  SDValue Test = DAG.getNode(NT, DL, MVT::Other, Pg, Op);
 
   // Convert CC to integer based on requested condition.
   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 48a4ea91c2782..3de6536ee4461 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -356,6 +356,7 @@ enum NodeType : unsigned {
   INSR,
   PTEST,
   PTEST_ANY,
+  PTEST_FIRST,
   PTRUE,
 
   CTTZ_ELTS,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 82061607f2e53..d3d53c004f68d 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1184,6 +1184,7 @@ bool AArch64InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
     break;
   case AArch64::PTEST_PP:
   case AArch64::PTEST_PP_ANY:
+  case AArch64::PTEST_PP_FIRST:
     SrcReg = MI.getOperand(0).getReg();
     SrcReg2 = MI.getOperand(1).getReg();
     // Not sure about the mask and value for now...
@@ -1355,12 +1356,25 @@ static bool areCFlagsAccessedBetweenInstrs(
   return false;
 }
 
-std::optional<unsigned>
+std::optional<std::pair<unsigned, MachineInstr *>>
 AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
                                       MachineInstr *Pred,
                                       const MachineRegisterInfo *MRI) const {
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
+
+  // Handle a COPY from the LSB of the results of a paired WHILEcc instruction.
+  if ((PredOpcode == TargetOpcode::COPY &&
+       Pred->getOperand(1).getSubReg() == AArch64::psub0) ||
+      // Handle unpack of the LSB of the result of a WHILEcc instruction.
+      PredOpcode == AArch64::PUNPKLO_PP) {
+    MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (MI && isWhileOpcode(MI->getOpcode())) {
+      Pred = MI;
+      PredOpcode = MI->getOpcode();
+    }
+  }
+
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
@@ -1369,15 +1383,16 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     // instruction and the condition is "any" since WHILcc does an implicit
     // PTEST(ALL, PG) check and PG is always a subset of ALL.
     if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
-      return PredOpcode;
+      return std::make_pair(PredOpcode, Pred);
 
-    // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
-    // redundant since WHILE performs an implicit PTEST with an all active
-    // mask.
+    // For PTEST(PTRUE_ALL, WHILE), since WHILE performs an implicit PTEST
+    // with an all active mask, the PTEST is redundant if either the element
+    // size matches or the PTEST condition is "first".
     if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
-        getElementSizeForOpcode(MaskOpcode) ==
-            getElementSizeForOpcode(PredOpcode))
-      return PredOpcode;
+        (PTest->getOpcode() == AArch64::PTEST_PP_FIRST ||
+         getElementSizeForOpcode(MaskOpcode) ==
+             getElementSizeForOpcode(PredOpcode)))
+      return std::make_pair(PredOpcode, Pred);
 
     return {};
   }
@@ -1388,7 +1403,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     // "any" since PG is always a subset of the governing predicate of the
     // ptest-like instruction.
     if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
-      return PredOpcode;
+      return std::make_pair(PredOpcode, Pred);
 
     // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
     // the element size matches and either the PTEST_LIKE instruction uses
@@ -1398,7 +1413,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
             getElementSizeForOpcode(PredOpcode)) {
       auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
       if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
-        return PredOpcode;
+        return std::make_pair(PredOpcode, Pred);
     }
 
     // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
@@ -1427,7 +1442,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
     if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
                                   PTest->getOpcode() == AArch64::PTEST_PP_ANY))
-      return PredOpcode;
+      return std::make_pair(PredOpcode, Pred);
 
     return {};
   }
@@ -1471,7 +1486,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     return {};
   }
 
-  return convertToFlagSettingOpc(PredOpcode);
+  return std::make_pair(convertToFlagSettingOpc(PredOpcode), Pred);
 }
 
 /// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
@@ -1481,10 +1496,12 @@ bool AArch64InstrInfo::optimizePTestInstr(
     const MachineRegisterInfo *MRI) const {
   auto *Mask = MRI->getUniqueVRegDef(MaskReg);
   auto *Pred = MRI->getUniqueVRegDef(PredReg);
+  unsigned NewOp;
   unsigned PredOpcode = Pred->getOpcode();
-  auto NewOp = canRemovePTestInstr(PTest, Mask, Pred, MRI);
-  if (!NewOp)
+  auto canRemove = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+  if (!canRemove)
     return false;
+  std::tie(NewOp, Pred) = *canRemove;
 
   const TargetRegisterInfo *TRI = &getRegisterInfo();
 
@@ -1498,8 +1515,8 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
   PTest->eraseFromParent();
-  if (*NewOp != PredOpcode) {
-    Pred->setDesc(get(*NewOp));
+  if (NewOp != PredOpcode) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
@@ -1560,7 +1577,8 @@ bool AArch64InstrInfo::optimizeCompareInstr(
   }
 
   if (CmpInstr.getOpcode() == AArch64::PTEST_PP ||
-      CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY)
+      CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY ||
+      CmpInstr.getOpcode() == AArch64::PTEST_PP_FIRST)
     return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI);
 
   if (SrcReg2 != 0)
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 792e0c3063b10..d722f433a150b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -572,7 +572,8 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
                           unsigned PredReg,
                           const MachineRegisterInfo *MRI) const;
-  std::optional<unsigned>
+
+  std::optional<std::pair<unsigned, MachineInstr *>>
   canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
                       MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
 };
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index bd5de628d8529..3cee3e92fae08 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -373,9 +373,10 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
      (AArch64fadda_p_node (SVEAllActive), node:$op2,
              (vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
 
-def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
-def AArch64ptest     : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
-def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
+def SDT_AArch64PTest   : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
+def AArch64ptest       : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
+def AArch64ptest_any   : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
+def AArch64ptest_first : SDNode<"AArch64ISD::PTEST_FIRST", SDT_AArch64PTest>;
 
 def SDT_AArch64DUP_PRED  : SDTypeProfile<1, 3,
   [SDTCisVec<0>, SDTCisSameAs<0, 3>, SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0, 1>]>;
@@ -948,7 +949,7 @@ let Predicates = [HasSVEorSME] in {
   defm BRKB_PPmP  : sve_int_break_m<0b101, "brkb",  int_aarch64_sve_brkb>;
   defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
 
-  defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any>;
+  defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any, AArch64ptest_first>;
   defm PFALSE   : sve_int_pfalse<0b000000, "pfalse">;
   defm PFIRST   : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
   defm PNEXT    : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index fc7d3cdda4acd..1c3528bed08c4 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -784,13 +784,16 @@ class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
 }
 
 multiclass sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op,
-                         SDPatternOperator op_any> {
+                         SDPatternOperator op_any, SDPatternOperator op_first> {
   def NAME : sve_int_ptest<opc, asm, op>;
 
   let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in {
   def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
                     [(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
              PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
+  def _FIRST : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
+                    [(op_first (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
+             PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
   }
 }
 
@@ -9669,7 +9672,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                             RegisterOperand ppr_ty, ElementSizeEnum EltSz>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -9687,16 +9690,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
   let Inst{3-1}   = Pd;
   let Inst{0}     = opc{0};
 
+  let ElementSize = EltSz;
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let isWhile = 1;
 }
 
 
 multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
 }
 
 
diff --git a/llvm/test/CodeGen/AArch64/opt-while-test.ll b/llvm/test/CodeGen/AArch64/opt-while-test.ll
index 3a43f817d6154..a022f4d8c9e23 100644
--- a/llvm/test/CodeGen/AArch64/opt-while-test.ll
+++ b/llvm/test/CodeGen/AArch64/opt-while-test.ll
@@ -8,10 +8,7 @@ define void @f_while(i32 %i, i32 %n) #0 {
 ; CHECK:       // %bb.0: // %E
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    whilelo p0.b, w0, w1
-; CHECK-NEXT:    punpklo p0.h, p0.b
-; CHECK-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-NEXT:    b.pl .LBB0_2
 ; CHECK-NEXT:  // %bb.1: // %A
 ; CHECK-NEXT:    bl g0
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -25,10 +22,7 @@ define void @f_while(i32 %i, i32 %n) #0 {
 ; CHECK-SVE2p1:       // %bb.0: // %E
 ; CHECK-SVE2p1-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-SVE2p1-NEXT:    whilelo p0.b, w0, w1
-; CHECK-SVE2p1-NEXT:    punpklo p0.h, p0.b
-; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE2p1-NEXT:    fmov w8, s0
-; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-SVE2p1-NEXT:    b.pl .LBB0_2
 ; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
 ; CHECK-SVE2p1-NEXT:    bl g0
 ; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -54,12 +48,9 @@ define void @f_while_x2(i32 %i, i32 %n) #0 {
 ; CHECK-LABEL: f_while_x2:
 ; CHECK:       // %bb.0: // %E
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    whilelo p1.b, w0, w1
-; CHECK-NEXT:    punpkhi p0.h, p1.b
-; CHECK-NEXT:    punpklo p1.h, p1.b
-; CHECK-NEXT:    mov z0.h, p1/z, #1 // =0x1
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-NEXT:    whilelo p0.b, w0, w1
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    b.pl .LBB1_2
 ; CHECK-NEXT:  // %bb.1: // %A
 ; CHECK-NEXT:    bl g0
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -75,15 +66,14 @@ define void @f_while_x2(i32 %i, i32 %n) #0 {
 ; CHECK-SVE2p1-NEXT:    mov w8, w1
 ; CHECK-SVE2p1-NEXT:    mov w9, w0
 ; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
-; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
-; CHECK-SVE2p1-NEXT:    fmov w8, s0
-; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-SVE2p1-NEXT:    b.pl .LBB1_2
 ; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
+; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
 ; CHECK-SVE2p1-NEXT:    bl g0
 ; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-SVE2p1-NEXT:    ret
 ; CHECK-SVE2p1-NEXT:  .LBB1_2: // %B
+; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
 ; CHECK-SVE2p1-NEXT:    bl g1
 ; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-SVE2p1-NEXT:    ret



More information about the llvm-commits mailing list