[llvm] [AArch64] Refactor redundant PTEST optimisations (NFC) (PR #87802)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 5 09:18:56 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

This patch refactors `AArch64InstrInfo::optimizePTestInstr` to simplify its convoluted conditions and control flow,
and to make it easier to add the optimisation in https://github.com/llvm/llvm-project/pull/81141

---

Patch is 474.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87802.diff


28 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2) 
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+59-2) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+92-78) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.h (+3) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+9) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+4-3) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+47-3) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+29-5) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+86-21) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+8-10) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1) 
- (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1) 
- (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+395) 
- (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1069) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll (+40-55) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/pr73894.ll (+2-2) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+842-830) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+165-157) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+656) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll (+48-54) 
- (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3) 
- (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+2-2) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index fa9392b86c15b9..eae993e60008e4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1241,6 +1241,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1999,6 +2001,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2622,6 +2627,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 63c2ef8912b29c..1c4f2f963e89f3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -531,6 +531,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 42d8f74fd427fb..297068943a63c6 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -890,6 +890,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 5f933b4587843c..39bf57737890c8 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -816,6 +816,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8218960406ec13..d0afa0a5c65330 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1873,8 +1873,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -20507,6 +20507,61 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
+static SDValue tryCombineWhileLo(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
+  if (HalfSize < 2)
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -20837,6 +20892,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_ptest_last:
     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
                     AArch64CC::LAST_ACTIVE);
+  case Intrinsic::aarch64_sve_whilelo:
+    return tryCombineWhileLo(N, DCI, Subtarget);
   }
   return SDValue();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 22687b0e31c284..1b422969379d25 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1351,48 +1351,52 @@ static bool areCFlagsAccessedBetweenInstrs(
   return false;
 }
 
-/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
-/// operation which could set the flags in an identical manner
-bool AArch64InstrInfo::optimizePTestInstr(
-    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
-    const MachineRegisterInfo *MRI) const {
-  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
-  auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
-  bool OpChanged = false;
-
+std::pair<bool, unsigned>
+AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                                      MachineInstr *Pred,
+                                      const MachineRegisterInfo *MRI) const {
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
-  if (isPTrueOpcode(MaskOpcode) && (PredIsPTestLike || PredIsWhileLike) &&
-      getElementSizeForOpcode(MaskOpcode) ==
-          getElementSizeForOpcode(PredOpcode) &&
-      Mask->getOperand(1).getImm() == 31) {
+  if (PredIsWhileLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of a WHILEcc
+    // instruction and the condition is "any" since WHILEcc does an implicit
+    // PTEST(ALL, PG) check and PG is always a subset of ALL.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
+
     // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
     // redundant since WHILE performs an implicit PTEST with an all active
-    // mask. Must be an all active predicate of matching element size.
+    // mask.
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode))
+      return {true, 0};
+
+    return {false, 0};
+  }
+
+  if (PredIsPTestLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
+    // instruction that sets the flags as PTEST would and the condition is
+    // "any" since PG is always a subset of the governing predicate of the
+    // ptest-like instruction.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
 
     // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
-    // PTEST_LIKE instruction uses the same all active mask and the element
-    // size matches. If the PTEST has a condition of any then it is always
-    // redundant.
-    if (PredIsPTestLike) {
+    // element size matches and either the PTEST_LIKE instruction uses
+    // the same all active mask or the condition is "any".
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode)) {
       auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PTestLikeMask && PTest->getOpcode() != AArch64::PTEST_PP_ANY)
-        return false;
+      if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+        return {true, 0};
     }
 
-    // Fallthough to simply remove the PTEST.
-  } else if ((Mask == Pred) && (PredIsPTestLike || PredIsWhileLike) &&
-             PTest->getOpcode() == AArch64::PTEST_PP_ANY) {
-    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
-    // instruction that sets the flags as PTEST would. This is only valid when
-    // the condition is any.
-
-    // Fallthough to simply remove the PTEST.
-  } else if (PredIsPTestLike) {
     // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
     // flags are set based on the same mask 'PG', but PTEST_LIKE must operate
     // on 8-bit predicates like the PTEST.  Otherwise, for instructions like
@@ -1417,55 +1421,65 @@ bool AArch64InstrInfo::optimizePTestInstr(
     // identical regardless of element size.
     auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
     uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
-    if ((Mask != PTestLikeMask) ||
-        (PredElementSize != AArch64::ElementSizeB &&
-         PTest->getOpcode() != AArch64::PTEST_PP_ANY))
-      return false;
+    if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
+                                  PTest->getOpcode() == AArch64::PTEST_PP_ANY))
+      return {true, 0};
 
-    // Fallthough to simply remove the PTEST.
-  } else {
-    // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
-    // opcode so the PTEST becomes redundant.
-    switch (PredOpcode) {
-    case AArch64::AND_PPzPP:
-    case AArch64::BIC_PPzPP:
-    case AArch64::EOR_PPzPP:
-    case AArch64::NAND_PPzPP:
-    case AArch64::NOR_PPzPP:
-    case AArch64::ORN_PPzPP:
-    case AArch64::ORR_PPzPP:
-    case AArch64::BRKA_PPzP:
-    case AArch64::BRKPA_PPzPP:
-    case AArch64::BRKB_PPzP:
-    case AArch64::BRKPB_PPzPP:
-    case AArch64::RDFFR_PPz: {
-      // Check to see if our mask is the same. If not the resulting flag bits
-      // may be different and we can't remove the ptest.
-      auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PredMask)
-        return false;
-      break;
-    }
-    case AArch64::BRKN_PPzP: {
-      // BRKN uses an all active implicit mask to set flags unlike the other
-      // flag-setting instructions.
-      // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
-      if ((MaskOpcode != AArch64::PTRUE_B) ||
-          (Mask->getOperand(1).getImm() != 31))
-        return false;
-      break;
-    }
-    case AArch64::PTRUE_B:
-      // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
-      break;
-    default:
-      // Bail out if we don't recognize the input
-      return false;
-    }
+    return {false, 0};
+  }
 
-    NewOp = convertToFlagSettingOpc(PredOpcode);
-    OpChanged = true;
+  // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
+  // opcode so the PTEST becomes redundant.
+  switch (PredOpcode) {
+  case AArch64::AND_PPzPP:
+  case AArch64::BIC_PPzPP:
+  case AArch64::EOR_PPzPP:
+  case AArch64::NAND_PPzPP:
+  case AArch64::NOR_PPzPP:
+  case AArch64::ORN_PPzPP:
+  case AArch64::ORR_PPzPP:
+  case AArch64::BRKA_PPzP:
+  case AArch64::BRKPA_PPzPP:
+  case AArch64::BRKB_PPzP:
+  case AArch64::BRKPB_PPzPP:
+  case AArch64::RDFFR_PPz: {
+    // Check to see if our mask is the same. If not the resulting flag bits
+    // may be different and we can't remove the ptest.
+    auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (Mask != PredMask)
+      return {false, 0};
+    break;
   }
+  case AArch64::BRKN_PPzP: {
+    // BRKN uses an all active implicit mask to set flags unlike the other
+    // flag-setting instructions.
+    // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
+    if ((MaskOpcode != AArch64::PTRUE_B) ||
+        (Mask->getOperand(1).getImm() != 31))
+      return {false, 0};
+    break;
+  }
+  case AArch64::PTRUE_B:
+    // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
+    break;
+  default:
+    // Bail out if we don't recognize the input
+    return {false, 0};
+  }
+
+  return {true, convertToFlagSettingOpc(PredOpcode)};
+}
+
+/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
+/// operation which could set the flags in an identical manner
+bool AArch64InstrInfo::optimizePTestInstr(
+    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
+    const MachineRegisterInfo *MRI) const {
+  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
+  auto *Pred = MRI->getUniqueVRegDef(PredReg);
+  auto [canRemove, NewOp] = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+  if (!canRemove)
+    return false;
 
   const TargetRegisterInfo *TRI = &getRegisterInfo();
 
@@ -1478,9 +1492,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
-  if (OpChanged) {
+  if (NewOp) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 2f10f80f4bdf70..7cc770a5b4eb49 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -432,6 +432,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
                           unsigned PredReg,
                           const MachineRegisterInfo *MRI) const;
+  std::pair<bool, unsigned>
+  canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                      MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
 };
 
 struct UsedNZCV {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ee7137b92445bb..66498817f55f73 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3362,6 +3362,15 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks bigger than `<vscale x 16 x i1>`.
+  unsigned N = ST->hasSVE() ? 16 : 0;
+  // Do not create masks that are more than twice the VF.
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index de39dea2be43e1..6501cc4a85e8d3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index ece2a34f180cb4..a4e4bd8c2bb4b2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -196,6 +196,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 49bacb5ae6cc4e..6cf139775475c6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -597,6 +597,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7525,7 +7529,8 @@ LoopVectorizationPlanner::executePlan(
   LLVM_DEBUG(BestVPlan.dump());
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 8ebd75da346546..27ff4c884cc566 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -214,12 +214,13 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
   return It;
 }
 
-VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
+VPTransformState::VPTransformState(ElementCount VF, unsigned UF,
+                                   ElementCount MaxPred, LoopInfo *LI,
                                    DominatorTree *DT, IRBui...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/87802


More information about the llvm-commits mailing list