[llvm] [SVE] Wide active lane mask (PR #76514)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 28 09:09:49 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

These patches make LoopVectorize generate lane masks longer than
the VF to allow the target to better utilise the instruction set.
The vectoriser emits one or more wide `llvm.get.active.lane.mask.*`
calls plus several `llvm.vector.extract.*` calls to yield the required
number of VF-wide masks.

The motivating example is a vectorised loop with unroll factor 2 that
can use the SVE2.1 `whilelo` instruction with predicate pair result, or
a SVE `whilelo` instruction with smaller element size plus
`punpklo`/`punpkhi`.

The width of the lane mask that the vectoriser emits is controlled
by a TargetTransformInfo hook `getMaxPredicateLength`. The default
implementation (return the same length as the VF) keeps the
change non-functional for targets that can't or are not prepared
to handle wider lane masks.


---

Patch is 426.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76514.diff


22 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2) 
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+105-33) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-2) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+8) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2) 
- (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+7-5) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+34-5) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+70-14) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+7-11) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1) 
- (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-32x1.ll (+31) 
- (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1003) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+886-874) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+196-188) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+655) 
- (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3) 
- (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+10-10) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 048912beaba5a1..8416f6138c6c46 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1216,6 +1216,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1952,6 +1954,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2557,6 +2562,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7ad3ce512a3552..e341220fa48b09 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -513,6 +513,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 5e7bdcdf72a49f..fcad6a86538256 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -881,6 +881,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 67246afa23147a..aaf39be8ff0634 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -789,6 +789,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dffe69bdb900db..bf5d48c903307f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1772,15 +1772,17 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
   // whilelo instruction for generating fixed-width predicates too.
   if (ResVT != MVT::nxv2i1 && ResVT != MVT::nxv4i1 && ResVT != MVT::nxv8i1 &&
       ResVT != MVT::nxv16i1 && ResVT != MVT::v2i1 && ResVT != MVT::v4i1 &&
-      ResVT != MVT::v8i1 && ResVT != MVT::v16i1)
+      ResVT != MVT::v8i1 && ResVT != MVT::v16i1 &&
+      (!(Subtarget->hasSVE2p1() || Subtarget->hasSME2()) ||
+       ResVT != MVT::nxv32i1)) // TODO: handle MVT::v32i1
     return true;
 
   // The whilelo instruction only works with i32 or i64 scalar inputs.
@@ -17760,22 +17762,49 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
-static bool isPredicateCCSettingOp(SDValue N) {
-  if ((N.getOpcode() == ISD::SETCC) ||
-      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
-        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
-        // get_active_lane_mask is lowered to a whilelo instruction.
-        N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
-    return true;
+static SDValue getPredicateCCSettingOp(SDValue N) {
+  if (N.getOpcode() == ISD::SETCC) {
+    EVT VT = N.getValueType();
+    return VT.isScalableVector() && VT.getVectorElementType() == MVT::i1
+               ? N
+               : SDValue();
+  }
 
-  return false;
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      isNullConstant(N.getOperand(1)))
+    N = N.getOperand(0);
+
+  if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return SDValue();
+
+  switch (N.getConstantOperandVal(0)) {
+  default:
+    return SDValue();
+  case Intrinsic::aarch64_sve_whilege_x2:
+  case Intrinsic::aarch64_sve_whilegt_x2:
+  case Intrinsic::aarch64_sve_whilehi_x2:
+  case Intrinsic::aarch64_sve_whilehs_x2:
+  case Intrinsic::aarch64_sve_whilele_x2:
+  case Intrinsic::aarch64_sve_whilelo_x2:
+  case Intrinsic::aarch64_sve_whilels_x2:
+  case Intrinsic::aarch64_sve_whilelt_x2:
+    if (N.getResNo() != 0)
+      return SDValue();
+    [[fallthrough]];
+  case Intrinsic::aarch64_sve_whilege:
+  case Intrinsic::aarch64_sve_whilegt:
+  case Intrinsic::aarch64_sve_whilehi:
+  case Intrinsic::aarch64_sve_whilehs:
+  case Intrinsic::aarch64_sve_whilele:
+  case Intrinsic::aarch64_sve_whilelo:
+  case Intrinsic::aarch64_sve_whilels:
+  case Intrinsic::aarch64_sve_whilelt:
+  case Intrinsic::get_active_lane_mask:
+    assert(N.getValueType().isScalableVector() &&
+           N.getValueType().getVectorElementType() == MVT::i1 &&
+           "Intrinsic expected to yield scalable i1 vector");
+    return N;
+  }
 }
 
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -17789,21 +17818,17 @@ performFirstTrueTestVectorCombine(SDNode *N,
   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
     return SDValue();
 
-  SDValue N0 = N->getOperand(0);
-  EVT VT = N0.getValueType();
-
-  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
-      !isNullConstant(N->getOperand(1)))
-    return SDValue();
-
-  // Restricted the DAG combine to only cases where we're extracting from a
-  // flag-setting operation.
-  if (!isPredicateCCSettingOp(N0))
+  // Restrict the DAG combine to only cases where we're extracting the zero-th
+  // element from the result of a flag-setting operation.
+  SDValue N0;
+  if (!isNullConstant(N->getOperand(1)) ||
+      !(N0 = getPredicateCCSettingOp(N->getOperand(0))))
     return SDValue();
 
   // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
+  SDValue Pg =
+      getPTrue(DAG, SDLoc(N), N0.getValueType(), AArch64SVEPredPattern::all);
   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
 }
 
@@ -19768,7 +19793,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   default:
     break;
   case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
     EVT VT = N->getValueType(0);
     if (VT.isFixedLengthVector()) {
       // We can use the SVE whilelo instruction to lower this intrinsic by
@@ -19791,15 +19815,63 @@ static SDValue performIntrinsicCombine(SDNode *N,
           EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
                            VT.getVectorElementCount());
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
+      SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                                N->getOperand(1), N->getOperand(2));
       Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
       Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
                         DAG.getConstant(0, DL, MVT::i64));
       Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
+
+      return Res;
     }
-    return Res;
+
+    if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+      return SDValue();
+
+    if (!N->hasNUsesOfValue(2, 0))
+      return SDValue();
+
+    auto It = N->use_begin();
+    SDNode *Lo = *It++;
+    SDNode *Hi = *It;
+
+    const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+    uint64_t OffLo, OffHi;
+    if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        Lo->getOperand(1)->getOpcode() != ISD::Constant ||
+        ((OffLo = Lo->getConstantOperandVal(1)) != 0 && OffLo != HalfSize) ||
+        Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        Hi->getOperand(1)->getOpcode() != ISD::Constant ||
+        ((OffHi = Hi->getConstantOperandVal(1)) != 0 && OffHi != HalfSize))
+      return SDValue();
+
+    if (OffLo > OffHi) {
+      std::swap(Lo, Hi);
+      std::swap(OffLo, OffHi);
+    }
+
+    if (OffLo != 0 || OffHi != HalfSize)
+      return SDValue();
+
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+    SDValue Idx = N->getOperand(1);
+    SDValue TC = N->getOperand(2);
+    if (Idx.getValueType() != MVT::i64) {
+      Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+      TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+    }
+    auto R =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                    {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+    DCI.CombineTo(Lo, R.getValue(0));
+    DCI.CombineTo(Hi, R.getValue(1));
+
+    return SDValue(N, 0);
   }
+
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 1cfbf4737a6f72..91a9af727ffb0b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1356,11 +1356,22 @@ bool AArch64InstrInfo::optimizePTestInstr(
     const MachineRegisterInfo *MRI) const {
   auto *Mask = MRI->getUniqueVRegDef(MaskReg);
   auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
+  unsigned NewOp;
   bool OpChanged = false;
 
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
+
+  // Handle a COPY from the LSB of a paired WHILEcc instruction.
+  if ((PredOpcode == TargetOpcode::COPY &&
+       Pred->getOperand(1).getSubReg() == AArch64::psub0)) {
+    MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (MI && isWhileOpcode(MI->getOpcode())) {
+      Pred = MI;
+      PredOpcode = MI->getOpcode();
+    }
+  }
+
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
@@ -1476,9 +1487,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
   if (OpChanged) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index b5b8b68291786d..58d7be9979fba8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3223,6 +3223,14 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks that are more than twice the VF.
+  unsigned N = ST->hasSVE2p1() ? 32 : ST->hasSVE() ? 16 : 0;
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0b220069a388b6..077320a7a23715 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index b7552541e950d9..963b62988f7001 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9754,7 +9754,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
 
 // SVE integer compare scalar count and limit (predicate pair)
 class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
-                             RegisterOperand ppr_ty>
+                             RegisterOperand ppr_ty, ElementSizeEnum EltSz>
     : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
         mnemonic, "\t$Pd, $Rn, $Rm",
         "", []>, Sched<[]> {
@@ -9772,16 +9772,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
   let Inst{3-1}   = Pd;
   let Inst{0}     = opc{0};
 
+  let ElementSize = EltSz;
   let Defs = [NZCV];
   let hasSideEffects = 0;
+  let isWhile = 1;
 }
 
 
 multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
- def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
- def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
- def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
- def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
+ def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
+ def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
+ def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
+ def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
 }
 
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 577ce8000de27b..ff19245b429eff 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -178,6 +178,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f82e161fb846d1..48c87b9bb8c554 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -608,6 +608,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7604,7 +7608,8 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan(
     VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 94cb7688981361..dd761aacd11fda 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -232,15 +232,16 @@ struct VPIteration {
 /// VPTransformState holds information passed down when "executing" a VPlan,
 /// needed for generating the output IR.
 struct VPTransformState {
-  VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
-                   DominatorTree *DT, IRBuilderBase &Builder,
+  VPTransformState(ElementCount VF, unsigned UF, ElementCount MaxPred,
+                   LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder,
                    InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx)
-      : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan),
-        LVer(nullptr), TypeAnalysis(Ctx) {}
+      : VF(VF), UF(UF), MaxPred(MaxPred), LI(LI), DT(DT), Builder(Builder),
+        ILV(ILV), Plan(Plan), LVer(nullptr), TypeAnalysis(Ctx) {}
 
   /// The chosen Vectorization and Unroll Factors of the loop being vectorized.
   ElementCount VF;
   unsigned UF;
+  ElementCount MaxPred;
 
   /// Hold the indices to generate specific scalar instructions. Null indicates
   /// that all instances are to be generated, using either scalar or vector
@@ -1164,7 +1165,6 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
     switch (getOpcode()) {
     default:
       return false;
-    case VPInstruction::ActiveLaneMask:
     case VPInstruction::CalculateTripCountMinusVF:
     case VPInstruction::CanonicalIVIncrementForPart:
     case VPInstruction::BranchOnCount:
@@ -1189,6 +1189,35 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
   }
 };
 
+class VPActiveLaneMaskRecipe : public VPRecipeWithIRFlags, public VPValue {
+  const std::string Name;
+  ElementCount MaxLength;
+
+public:
+  VPActiveLaneMaskRecipe(VPValu...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/76514


More information about the llvm-commits mailing list