[llvm] [VPlan] Add new VPInstruction opcode for header mask. (PR #89603)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 10:09:55 PDT 2024


================
@@ -1203,52 +1164,32 @@ static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch(
   return LaneMaskPhi;
 }
 
-/// Collect all VPValues representing a header mask through the (ICMP_ULE,
-/// WideCanonicalIV, backedge-taken-count) pattern.
-/// TODO: Introduce explicit recipe for header-mask instead of searching
-/// for the header-mask pattern manually.
-static SmallVector<VPValue *> collectAllHeaderMasks(VPlan &Plan) {
-  SmallVector<VPValue *> WideCanonicalIVs;
-  auto *FoundWidenCanonicalIVUser =
-      find_if(Plan.getCanonicalIV()->users(),
-              [](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); });
-  assert(count_if(Plan.getCanonicalIV()->users(),
-                  [](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); }) <=
-             1 &&
-         "Must have at most one VPWideCanonicalIVRecipe");
-  if (FoundWidenCanonicalIVUser != Plan.getCanonicalIV()->users().end()) {
-    auto *WideCanonicalIV =
-        cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
-    WideCanonicalIVs.push_back(WideCanonicalIV);
-  }
-
-  // Also include VPWidenIntOrFpInductionRecipes that represent a widened
-  // version of the canonical induction.
+/// Return the header mask recipe of the VPlan, if there is one.
+static VPInstruction *getHeaderMask(VPlan &Plan) {
   VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
-  for (VPRecipeBase &Phi : HeaderVPBB->phis()) {
-    auto *WidenOriginalIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi);
-    if (WidenOriginalIV && WidenOriginalIV->isCanonical())
-      WideCanonicalIVs.push_back(WidenOriginalIV);
-  }
+  auto R = find_if(*HeaderVPBB, [](VPRecipeBase &R) {
+    using namespace llvm::VPlanPatternMatch;
+    return match(&R, m_VPInstruction<VPInstruction::HeaderMask>(m_VPValue()));
+  });
+  return R == HeaderVPBB->end() ? nullptr : cast<VPInstruction>(&*R);
+}
 
-  // Walk users of wide canonical IVs and collect to all compares of the form
-  // (ICMP_ULE, WideCanonicalIV, backedge-taken-count).
-  SmallVector<VPValue *> HeaderMasks;
-  VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
-  for (auto *Wide : WideCanonicalIVs) {
-    for (VPUser *U : SmallVector<VPUser *>(Wide->users())) {
-      auto *HeaderMask = dyn_cast<VPInstruction>(U);
-      if (!HeaderMask || HeaderMask->getOpcode() != Instruction::ICmp ||
-          HeaderMask->getPredicate() != CmpInst::ICMP_ULE ||
-          HeaderMask->getOperand(1) != BTC)
-        continue;
+static VPValue *getOrCreateWideCanonicalIV(VPlan &Plan,
+                                           VPRecipeBase *InsertPt) {
 
-      assert(HeaderMask->getOperand(0) == Wide &&
-             "WidenCanonicalIV must be the first operand of the compare");
-      HeaderMasks.push_back(HeaderMask);
-    }
+  VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
+  for (VPRecipeBase &R : HeaderVPBB->phis()) {
+    auto *WideIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R);
+    if (!WideIV || !WideIV->isCanonical() ||
+        Plan.getCanonicalIV()->getScalarType() != WideIV->getScalarType())
+      continue;
+    return WideIV;
+    break;
----------------
ayalz wrote:

```suggestion
  Type *CanonicalType = Plan.getCanonicalIV()->getScalarType();
  for (VPRecipeBase &R : HeaderVPBB->phis()) {
    auto *WideIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R);
    if (WideIV && WideIV->isCanonical() &&
         WideIV->getScalarType() == CanonicalType)
      return WideIV;
```

https://github.com/llvm/llvm-project/pull/89603


More information about the llvm-commits mailing list