[llvm] [VPlan] Move predication to VPlanTransform (NFC). (PR #128420)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 20 07:01:48 PDT 2025
================
@@ -0,0 +1,301 @@
+//===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements predication for VPlans.
+///
+//===----------------------------------------------------------------------===//
+
+#include "VPRecipeBuilder.h"
+#include "VPlan.h"
+#include "VPlanCFG.h"
+#include "VPlanTransforms.h"
+#include "VPlanUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
+
+using namespace llvm;
+
+namespace {
+/// Computes block-in masks and edge masks for the blocks of a VPlan so that
+/// they can be if-converted (predicated).
+class VPPredicator {
+  /// Map from a block to its entry mask; nullptr denotes an all-one mask.
+  using BlockMaskCacheTy = DenseMap<VPBasicBlock *, VPValue *>;
+
+  /// Builder to construct recipes to compute masks.
+  VPBuilder Builder;
+
+  /// When we if-convert we need to create edge masks. We have to cache values
+  /// so that we don't end up with exponential recursion/IR.
+  using EdgeMaskCacheTy =
+      DenseMap<std::pair<const VPBasicBlock *, const VPBasicBlock *>,
+               VPValue *>;
+  EdgeMaskCacheTy EdgeMaskCache;
+
+  /// Block-in masks computed so far. Held by reference: the map is owned by
+  /// the caller, which therefore still sees the masks after predication.
+  BlockMaskCacheTy &BlockMaskCache;
+
+  /// Create an edge mask for every destination of cases and/or default.
+  void createSwitchEdgeMasks(VPInstruction *SI);
+
+  /// Computes and returns the predicate of the edge between \p Src and \p Dst.
+  VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);
+
+  /// Returns the *entry* mask for \p VPBB, or nullptr if it has none recorded
+  /// (nullptr is also used to represent an all-one mask).
+  VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
+    return BlockMaskCache.lookup(VPBB);
+  }
+
+  /// Records \p Mask as the entry mask of \p VPBB; a block's mask may only be
+  /// set once.
+  void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
+    // TODO: Include the masks as operands in the predicated VPlan directly to
+    // remove the need to keep a map of masks beyond the predication transform.
+    assert(!getBlockInMask(VPBB) && "Mask already set");
+    BlockMaskCache[VPBB] = Mask;
+  }
+
+  /// Caches \p Mask as the predicate of the edge \p Src -> \p Dst and returns
+  /// it, so callers can cache-and-return in a single statement.
+  VPValue *setEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst,
+                       VPValue *Mask) {
+    assert(!getEdgeMask(Src, Dst) && "Mask already set");
+    return EdgeMaskCache[{Src, Dst}] = Mask;
+  }
+
+public:
+  /// \p BlockMaskCache receives the computed block-in masks; it must outlive
+  /// this predicator.
+  explicit VPPredicator(BlockMaskCacheTy &BlockMaskCache)
+      : BlockMaskCache(BlockMaskCache) {}
+
+  /// Returns the precomputed predicate of the edge from \p Src to \p Dst, or
+  /// nullptr if none has been cached.
+  VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
+    return EdgeMaskCache.lookup({Src, Dst});
+  }
+
+  /// Compute the mask for the vector loop header block \p HeaderVPBB and
+  /// record it as the header's block-in mask. Without tail folding the header
+  /// is unmasked (all-one); with \p FoldTail a mask based on comparing the
+  /// widened canonical IV against the backedge-taken count is built.
+  void createHeaderMask(VPBasicBlock *HeaderVPBB, bool FoldTail);
+
+  /// Compute and return the predicate of \p VPBB, assuming that the header
+  /// block of the loop is set to True or the loop mask when tail folding.
+  VPValue *createBlockInMask(VPBasicBlock *VPBB);
+};
+} // namespace
+
+/// Returns the mask guarding the CFG edge \p Src -> \p Dst, creating and
+/// caching it on first request. A nullptr result stands for an all-one mask.
+VPValue *VPPredicator::createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst) {
+  assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge");
+
+  // Reuse a mask computed earlier for this very edge.
+  if (VPValue *Cached = getEdgeMask(Src, Dst))
+    return Cached;
+
+  VPValue *SrcMask = getBlockInMask(Src);
+
+  // A block with a single successor has no terminator recipe: the edge is
+  // taken whenever the block executes, so it inherits the block's mask.
+  if (Src->getNumSuccessors() == 1)
+    return setEdgeMask(Src, Dst, SrcMask);
+
+  auto *Term = cast<VPInstruction>(Src->getTerminator());
+  // Switches populate the masks of all their outgoing edges in one go.
+  if (Term->getOpcode() == Instruction::Switch) {
+    createSwitchEdgeMasks(Term);
+    return getEdgeMask(Src, Dst);
+  }
+
+  assert(Term->getOpcode() == VPInstruction::BranchOnCond &&
+         "Unsupported terminator");
+  const auto &Succs = Src->getSuccessors();
+  // When both targets coincide the branch condition is irrelevant.
+  if (Succs[0] == Succs[1])
+    return setEdgeMask(Src, Dst, SrcMask);
+
+  VPValue *Mask = Term->getOperand(0);
+  assert(Mask && "No Edge Mask found for condition");
+
+  // The condition holds on the edge to the first successor; negate it for
+  // the edge to the second one.
+  if (Dst != Succs[0])
+    Mask = Builder.createNot(Mask, Term->getDebugLoc());
+
+  // A null SrcMask is all-one, so the edge mask needs no conjunction. Use a
+  // logical (select-based) 'and' rather than a bitwise one: the latter would
+  // introduce UB when SrcMask is false and the condition is poison.
+  if (SrcMask)
+    Mask = Builder.createLogicalAnd(SrcMask, Mask, Term->getDebugLoc());
+
+  return setEdgeMask(Src, Dst, Mask);
+}
+
+/// Builds the entry mask of \p VPBB as the disjunction of the masks of its
+/// distinct incoming edges, records it, and returns it (nullptr = all-one).
+VPValue *VPPredicator::createBlockInMask(VPBasicBlock *VPBB) {
+  // Any recipes needed for this block's mask go at the start of the block.
+  Builder.setInsertPoint(VPBB, VPBB->begin());
+
+  // Each predecessor contributes once, even if it reaches VPBB via several
+  // successor slots.
+  SetVector<VPBlockBase *> UniquePreds(VPBB->getPredecessors().begin(),
+                                       VPBB->getPredecessors().end());
+
+  // nullptr means "no mask", i.e. all lanes active, following the convention
+  // for masked load/store/gather/scatter.
+  VPValue *Mask = nullptr;
+  for (VPBlockBase *Pred : UniquePreds) {
+    VPValue *InMask = createEdgeMask(cast<VPBasicBlock>(Pred), VPBB);
+    // One all-one incoming edge makes the whole block mask all-one.
+    if (!InMask) {
+      Mask = nullptr;
+      break;
+    }
+    if (Mask)
+      Mask = Builder.createOr(Mask, InMask, {});
+    else
+      Mask = InMask;
+  }
+
+  setBlockInMask(VPBB, Mask);
+  return Mask;
+}
+
+void VPPredicator::createHeaderMask(VPBasicBlock *HeaderVPBB, bool FoldTail) {
+ if (!FoldTail) {
+ setBlockInMask(HeaderVPBB, nullptr);
+ return;
+ }
+
+ // Introduce the early-exit compare IV <= BTC to form header block mask.
+ // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by
+ // constructing the desired canonical IV in the header block as its first
+ // non-phi instructions.
+
+ auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi();
+ auto &Plan = *HeaderVPBB->getPlan();
+ auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
+ HeaderVPBB->insert(IV, NewInsertionPoint);
----------------
ayalz wrote:
```suggestion
auto &Plan = *HeaderVPBB->getPlan();
auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi();
HeaderVPBB->insert(IV, NewInsertionPoint);
```
https://github.com/llvm/llvm-project/pull/128420
More information about the llvm-commits
mailing list