[llvm] [AArch64] Optimise test of the LSB of a paired whileCC instruction (PR #81141)

Momchil Velikov via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 19 01:33:32 PDT 2024


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/81141

>From 07e231b3d4793426c2789b2f1ee498ad72c9e7ba Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 11 Jun 2024 15:51:15 +0100
Subject: [PATCH 1/2] [AArch64] Precommit testcase for optimised test of the
 LSB of a paired whileCC instruction

Change-Id: I5058e24c631ede0a04399b39e5096f898fa8f792
---
 llvm/test/CodeGen/AArch64/opt-while-test.ll | 107 ++++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/opt-while-test.ll

diff --git a/llvm/test/CodeGen/AArch64/opt-while-test.ll b/llvm/test/CodeGen/AArch64/opt-while-test.ll
new file mode 100644
index 0000000000000..3a43f817d6154
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/opt-while-test.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc                < %s | FileCheck %s
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefix=CHECK-SVE2p1
+target triple = "aarch64-linux"
+
+define void @f_while(i32 %i, i32 %n) #0 {
+; CHECK-LABEL: f_while:
+; CHECK:       // %bb.0: // %E
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    whilelo p0.b, w0, w1
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-NEXT:  // %bb.1: // %A
+; CHECK-NEXT:    bl g0
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB0_2: // %B
+; CHECK-NEXT:    bl g1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f_while:
+; CHECK-SVE2p1:       // %bb.0: // %E
+; CHECK-SVE2p1-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE2p1-NEXT:    punpklo p0.h, p0.b
+; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE2p1-NEXT:    fmov w8, s0
+; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
+; CHECK-SVE2p1-NEXT:    bl g0
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+; CHECK-SVE2p1-NEXT:  .LBB0_2: // %B
+; CHECK-SVE2p1-NEXT:    bl g1
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+E:
+  %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n)
+  %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0)
+  %elt = extractelement <vscale x 8 x i1> %mask, i64 0
+  br i1 %elt, label %A, label %B
+A:
+  call void @g0()
+  ret void
+B:
+  call void @g1()
+  ret void
+}
+
+define void @f_while_x2(i32 %i, i32 %n) #0 {
+; CHECK-LABEL: f_while_x2:
+; CHECK:       // %bb.0: // %E
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    whilelo p1.b, w0, w1
+; CHECK-NEXT:    punpkhi p0.h, p1.b
+; CHECK-NEXT:    punpklo p1.h, p1.b
+; CHECK-NEXT:    mov z0.h, p1/z, #1 // =0x1
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-NEXT:  // %bb.1: // %A
+; CHECK-NEXT:    bl g0
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB1_2: // %B
+; CHECK-NEXT:    bl g1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-SVE2p1-LABEL: f_while_x2:
+; CHECK-SVE2p1:       // %bb.0: // %E
+; CHECK-SVE2p1-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
+; CHECK-SVE2p1-NEXT:    fmov w8, s0
+; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
+; CHECK-SVE2p1-NEXT:    bl g0
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+; CHECK-SVE2p1-NEXT:  .LBB1_2: // %B
+; CHECK-SVE2p1-NEXT:    bl g1
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    ret
+E:
+  %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n)
+  %mask.hi = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 8)
+  %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0)
+  %elt = extractelement <vscale x 8 x i1> %mask, i64 0
+  br i1 %elt, label %A, label %B
+A:
+  call void @g0(<vscale x 8 x i1> %mask.hi)
+  ret void
+B:
+  call void @g1(<vscale x 8 x i1> %mask.hi)
+  ret void
+}
+
+declare void @g0(...)
+declare void @g1(...)
+
+attributes #0 = { nounwind vscale_range(1,16) "target-cpu"="neoverse-v1" }

>From 399074941d31b23b72111716b8e781b77f10b36c Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 12 Jun 2024 10:25:45 +0100
Subject: [PATCH 2/2] [AArch64] Optimise test of the LSB of a paired whileCC
 instruction

Change-Id: Iefc0eb7e4b90715ae08c154dde5bda1091f9de07
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 63 ++++++++++++++-----
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  1 +
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  | 52 ++++++++++-----
 llvm/lib/Target/AArch64/AArch64InstrInfo.h    |  3 +-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  9 +--
 llvm/lib/Target/AArch64/SVEInstrFormats.td    | 17 +++--
 llvm/test/CodeGen/AArch64/opt-while-test.ll   | 26 +++-----
 7 files changed, 109 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9f6f66e9e0c70..e8cf16e28b437 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2727,6 +2727,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::INSR)
     MAKE_CASE(AArch64ISD::PTEST)
     MAKE_CASE(AArch64ISD::PTEST_ANY)
+    MAKE_CASE(AArch64ISD::PTEST_FIRST)
     MAKE_CASE(AArch64ISD::PTRUE)
     MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
     MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
@@ -18733,21 +18734,41 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
 static bool isPredicateCCSettingOp(SDValue N) {
-  if ((N.getOpcode() == ISD::SETCC) ||
-      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
-        // get_active_lane_mask is lowered to a whilelo instruction.
-        N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
+  if (N.getOpcode() == ISD::SETCC)
     return true;
 
-  return false;
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      isNullConstant(N.getOperand(1)))
+    N = N.getOperand(0);
+
+  if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return false;
+
+  switch (N.getConstantOperandVal(0)) {
+  default:
+    return false;
+  case Intrinsic::aarch64_sve_whilege_x2:
+  case Intrinsic::aarch64_sve_whilegt_x2:
+  case Intrinsic::aarch64_sve_whilehi_x2:
+  case Intrinsic::aarch64_sve_whilehs_x2:
+  case Intrinsic::aarch64_sve_whilele_x2:
+  case Intrinsic::aarch64_sve_whilelo_x2:
+  case Intrinsic::aarch64_sve_whilels_x2:
+  case Intrinsic::aarch64_sve_whilelt_x2:
+    if (N.getResNo() != 0)
+      return false;
+    [[fallthrough]];
+  case Intrinsic::aarch64_sve_whilege:
+  case Intrinsic::aarch64_sve_whilegt:
+  case Intrinsic::aarch64_sve_whilehi:
+  case Intrinsic::aarch64_sve_whilehs:
+  case Intrinsic::aarch64_sve_whilele:
+  case Intrinsic::aarch64_sve_whilelo:
+  case Intrinsic::aarch64_sve_whilels:
+  case Intrinsic::aarch64_sve_whilelt:
+  case Intrinsic::get_active_lane_mask:
+    return true;
+  }
 }
 
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -20666,9 +20687,19 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
   }
 
   // Set condition code (CC) flags.
-  SDValue Test = DAG.getNode(
-      Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
-      DL, MVT::Other, Pg, Op);
+  AArch64ISD::NodeType NT;
+  switch (Cond) {
+  default:
+    NT = AArch64ISD::PTEST;
+    break;
+  case AArch64CC::ANY_ACTIVE:
+    NT = AArch64ISD::PTEST_ANY;
+    break;
+  case AArch64CC::FIRST_ACTIVE:
+    NT = AArch64ISD::PTEST_FIRST;
+    break;
+  }
+  SDValue Test = DAG.getNode(NT, DL, MVT::Other, Pg, Op);
 
   // Convert CC to integer based on requested condition.
   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 986f1b67ee513..cb1774e193aad 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -359,6 +359,7 @@ enum NodeType : unsigned {
   INSR,
   PTEST,
   PTEST_ANY,
+  PTEST_FIRST,
   PTRUE,
 
   CTTZ_ELTS,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 949e7699d070d..bad1f63c83da4 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1184,6 +1184,7 @@ bool AArch64InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
     break;
   case AArch64::PTEST_PP:
   case AArch64::PTEST_PP_ANY:
+  case AArch64::PTEST_PP_FIRST:
     SrcReg = MI.getOperand(0).getReg();
     SrcReg2 = MI.getOperand(1).getReg();
     // Not sure about the mask and value for now...
@@ -1355,12 +1356,25 @@ static bool areCFlagsAccessedBetweenInstrs(
   return false;
 }
 
-std::optional<unsigned>
+std::optional<std::pair<unsigned, MachineInstr *>>
 AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
                                       MachineInstr *Pred,
                                       const MachineRegisterInfo *MRI) const {
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
+
+  // Handle a COPY from the LSB of the results of a paired WHILEcc instruction.
+  if ((PredOpcode == TargetOpcode::COPY &&
+       Pred->getOperand(1).getSubReg() == AArch64::psub0) ||
+      // Handle unpack of the LSB of the result of a WHILEcc instruction.
+      PredOpcode == AArch64::PUNPKLO_PP) {
+    MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (MI && isWhileOpcode(MI->getOpcode())) {
+      Pred = MI;
+      PredOpcode = MI->getOpcode();
+    }
+  }
+
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
@@ -1369,15 +1383,16 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     // instruction and the condition is "any" since WHILcc does an implicit
     // PTEST(ALL, PG) check and PG is always a subset of ALL.
     if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
-      return PredOpcode;
+      return std::make_pair(PredOpcode, Pred);
 
-    // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
-    // redundant since WHILE performs an implicit PTEST with an all active
-    // mask.
+    // For PTEST(PTRUE_ALL, WHILE), since WHILE performs an implicit PTEST
+    // with an all active mask, the PTEST is redundant if either the element
+    // size matches or the PTEST condition is "first".
     if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
-        getElementSizeForOpcode(MaskOpcode) ==
-            getElementSizeForOpcode(PredOpcode))
-      return PredOpcode;
+        (PTest->getOpcode() == AArch64::PTEST_PP_FIRST ||
+         getElementSizeForOpcode(MaskOpcode) ==
+             getElementSizeForOpcode(PredOpcode)))
+      return std::make_pair(PredOpcode, Pred);
 
     return {};
   }
@@ -1388,7 +1403,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     // "any" since PG is always a subset of the governing predicate of the
     // ptest-like instruction.
     if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
-      return PredOpcode;
+      return std::make_pair(PredOpcode, Pred);
 
     // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
     // the element size matches and either the PTEST_LIKE instruction uses
@@ -1398,7 +1413,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
             getElementSizeForOpcode(PredOpcode)) {
       auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
       if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
-        return PredOpcode;
+        return std::make_pair(PredOpcode, Pred);
     }
 
     // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
@@ -1427,7 +1442,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
     if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
                                   PTest->getOpcode() == AArch64::PTEST_PP_ANY))
-      return PredOpcode;
+      return std::make_pair(PredOpcode, Pred);
 
     return {};
   }
@@ -1471,7 +1486,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     return {};
   }
 
-  return convertToFlagSettingOpc(PredOpcode);
+  return std::make_pair(convertToFlagSettingOpc(PredOpcode), Pred);
 }
 
 /// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
@@ -1481,10 +1496,12 @@ bool AArch64InstrInfo::optimizePTestInstr(
     const MachineRegisterInfo *MRI) const {
   auto *Mask = MRI->getUniqueVRegDef(MaskReg);
   auto *Pred = MRI->getUniqueVRegDef(PredReg);
+  unsigned NewOp;
   unsigned PredOpcode = Pred->getOpcode();
-  auto NewOp = canRemovePTestInstr(PTest, Mask, Pred, MRI);
-  if (!NewOp)
+  auto canRemove = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+  if (!canRemove)
     return false;
+  std::tie(NewOp, Pred) = *canRemove;
 
   const TargetRegisterInfo *TRI = &getRegisterInfo();
 
@@ -1498,8 +1515,8 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
   PTest->eraseFromParent();
-  if (*NewOp != PredOpcode) {
-    Pred->setDesc(get(*NewOp));
+  if (NewOp != PredOpcode) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
@@ -1560,7 +1577,8 @@ bool AArch64InstrInfo::optimizeCompareInstr(
   }
 
   if (CmpInstr.getOpcode() == AArch64::PTEST_PP ||
-      CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY)
+      CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY ||
+      CmpInstr.getOpcode() == AArch64::PTEST_PP_FIRST)
     return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI);
 
   if (SrcReg2 != 0)
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 792e0c3063b10..d722f433a150b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -572,7 +572,8 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
                           unsigned PredReg,
                           const MachineRegisterInfo *MRI) const;
-  std::optional<unsigned>
+
+  std::optional<std::pair<unsigned, MachineInstr *>>
   canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
                       MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
 };
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index bd5de628d8529..3cee3e92fae08 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -373,9 +373,10 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
      (AArch64fadda_p_node (SVEAllActive), node:$op2,
              (vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
 
-def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
-def AArch64ptest     : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
-def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
+def SDT_AArch64PTest   : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
+def AArch64ptest       : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
+def AArch64ptest_any   : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
+def AArch64ptest_first : SDNode<"AArch64ISD::PTEST_FIRST", SDT_AArch64PTest>;
 
 def SDT_AArch64DUP_PRED  : SDTypeProfile<1, 3,
   [SDTCisVec<0>, SDTCisSameAs<0, 3>, SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0, 1>]>;
@@ -948,7 +949,7 @@ let Predicates = [HasSVEorSME] in {
   defm BRKB_PPmP  : sve_int_break_m<0b101, "brkb",  int_aarch64_sve_brkb>;
   defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
 
-  defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any>;
+  defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any, AArch64ptest_first>;
   defm PFALSE   : sve_int_pfalse<0b000000, "pfalse">;
   defm PFIRST   : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
   defm PNEXT    : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index fc7d3cdda4acd..1c3528bed08c4 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -784,13 +784,16 @@ class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
 }
 
 multiclass sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op,
-                         SDPatternOperator op_any> {
+                         SDPatternOperator op_any, SDPatternOperator op_first> {
   def NAME : sve_int_ptest<opc, asm, op>;
 
   let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in {
   def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
                     [(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
              PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
+  def _FIRST : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
+                    [(op_first (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
+             PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
   }
 }
 
@@ -9669,7 +9672,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                             RegisterOperand ppr_ty, ElementSizeEnum EltSz>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -9687,16 +9690,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
   let Inst{3-1}   = Pd;
   let Inst{0}     = opc{0};
 
+  let ElementSize = EltSz;
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let isWhile = 1;
 }
 
 
 multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
 }
 
 
diff --git a/llvm/test/CodeGen/AArch64/opt-while-test.ll b/llvm/test/CodeGen/AArch64/opt-while-test.ll
index 3a43f817d6154..a022f4d8c9e23 100644
--- a/llvm/test/CodeGen/AArch64/opt-while-test.ll
+++ b/llvm/test/CodeGen/AArch64/opt-while-test.ll
@@ -8,10 +8,7 @@ define void @f_while(i32 %i, i32 %n) #0 {
 ; CHECK:       // %bb.0: // %E
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    whilelo p0.b, w0, w1
-; CHECK-NEXT:    punpklo p0.h, p0.b
-; CHECK-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-NEXT:    b.pl .LBB0_2
 ; CHECK-NEXT:  // %bb.1: // %A
 ; CHECK-NEXT:    bl g0
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -25,10 +22,7 @@ define void @f_while(i32 %i, i32 %n) #0 {
 ; CHECK-SVE2p1:       // %bb.0: // %E
 ; CHECK-SVE2p1-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-SVE2p1-NEXT:    whilelo p0.b, w0, w1
-; CHECK-SVE2p1-NEXT:    punpklo p0.h, p0.b
-; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE2p1-NEXT:    fmov w8, s0
-; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-SVE2p1-NEXT:    b.pl .LBB0_2
 ; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
 ; CHECK-SVE2p1-NEXT:    bl g0
 ; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -54,12 +48,9 @@ define void @f_while_x2(i32 %i, i32 %n) #0 {
 ; CHECK-LABEL: f_while_x2:
 ; CHECK:       // %bb.0: // %E
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    whilelo p1.b, w0, w1
-; CHECK-NEXT:    punpkhi p0.h, p1.b
-; CHECK-NEXT:    punpklo p1.h, p1.b
-; CHECK-NEXT:    mov z0.h, p1/z, #1 // =0x1
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-NEXT:    whilelo p0.b, w0, w1
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    b.pl .LBB1_2
 ; CHECK-NEXT:  // %bb.1: // %A
 ; CHECK-NEXT:    bl g0
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -75,15 +66,14 @@ define void @f_while_x2(i32 %i, i32 %n) #0 {
 ; CHECK-SVE2p1-NEXT:    mov w8, w1
 ; CHECK-SVE2p1-NEXT:    mov w9, w0
 ; CHECK-SVE2p1-NEXT:    whilelo { p0.h, p1.h }, x9, x8
-; CHECK-SVE2p1-NEXT:    mov z0.h, p0/z, #1 // =0x1
-; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
-; CHECK-SVE2p1-NEXT:    fmov w8, s0
-; CHECK-SVE2p1-NEXT:    tbz w8, #0, .LBB1_2
+; CHECK-SVE2p1-NEXT:    b.pl .LBB1_2
 ; CHECK-SVE2p1-NEXT:  // %bb.1: // %A
+; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
 ; CHECK-SVE2p1-NEXT:    bl g0
 ; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-SVE2p1-NEXT:    ret
 ; CHECK-SVE2p1-NEXT:  .LBB1_2: // %B
+; CHECK-SVE2p1-NEXT:    mov p0.b, p1.b
 ; CHECK-SVE2p1-NEXT:    bl g1
 ; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-SVE2p1-NEXT:    ret



More information about the llvm-commits mailing list