[llvm] [LV] Introduce the EVLIVSimplify Pass for EVL-vectorized loops (PR #91796)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri May 24 09:12:54 PDT 2024


================
@@ -0,0 +1,296 @@
+//===------ EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass rewrites a vectorized loop that uses a canonical IV so that it
+// uses an EVL-based IV instead, provided the loop was tail-folded by
+// predicated EVL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/EVLIndVarSimplify.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/IVDescriptors.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar/LoopPassManager.h"
+
+#define DEBUG_TYPE "evl-iv-simplify"
+
+using namespace llvm;
+
+// Counter reported through LLVM's statistics machinery (-stats).
+STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated");
+
+// Hidden switch for the whole transformation, on by default. It is checked at
+// the very beginning of EVLIndVarSimplifyImpl::run.
+static cl::opt<bool> EnableEVLIndVarSimplify(
+    "enable-evl-indvar-simplify", cl::init(true), cl::Hidden,
+    cl::desc("Enable EVL-based induction variable simplify Pass"));
+
+namespace {
+/// Core logic of the EVL induction-variable simplification. It can be built
+/// either from new-pass-manager loop analysis results or directly from a
+/// ScalarEvolution instance; ScalarEvolution is the only analysis it keeps.
+struct EVLIndVarSimplifyImpl {
+  ScalarEvolution &SE;
+
+  explicit EVLIndVarSimplifyImpl(ScalarEvolution &SE) : SE(SE) {}
+
+  explicit EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR)
+      : EVLIndVarSimplifyImpl(LAR.SE) {}
+
+  /// Try to apply the transformation to \p L.
+  /// \returns true if the loop was modified.
+  bool run(Loop &L);
+};
+
+/// Legacy pass-manager (LoopPass) entry point for this transformation.
+struct EVLIndVarSimplify : public LoopPass {
+  static char ID;
+
+  EVLIndVarSimplify() : LoopPass(ID) {
+    PassRegistry &Registry = *PassRegistry::getPassRegistry();
+    initializeEVLIndVarSimplifyPass(Registry);
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+
+  /// Declares a dependency on ScalarEvolution and promises not to change the
+  /// control-flow graph.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<ScalarEvolutionWrapperPass>();
+  }
+};
+} // end anonymous namespace
+
+// Infer the vectorization factor (VF) from the step of the canonical IV.
+//
+// Two step shapes are recognized:
+//  * `(<constant> x vscale)`. The magnitude of the constant is the VF, so a
+//    decreasing IV such as `(-4 x vscale)` is accepted as well.
+//  * A plain constant step, when the function's vscale_range pins vscale to a
+//    single value. The step magnitude must then be an exact multiple of that
+//    vscale, and the VF is the quotient.
+//
+// Returns the VF, or 0 if it cannot be inferred or does not fit in 32 bits.
+static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) {
+  if (!Step)
+    return 0U;
+
+  // Looking for loops with IV step value in the form of `(<constant VF> x
+  // vscale)`.
+  if (auto *Mul = dyn_cast<SCEVMulExpr>(Step)) {
+    if (Mul->getNumOperands() == 2) {
+      const SCEV *LHS = Mul->getOperand(0);
+      const SCEV *RHS = Mul->getOperand(1);
+      if (auto *Const = dyn_cast<SCEVConstant>(LHS)) {
+        // Use the magnitude so this path agrees with the constant-step path
+        // below, which also takes abs() of the step.
+        uint64_t V = Const->getAPInt().abs().getLimitedValue();
+        if (isa<SCEVVScale>(RHS) && llvm::isUInt<32>(V))
+          return static_cast<uint32_t>(V);
+      }
+    }
+  }
+
+  // If not, see if the vscale_range of the parent function is a fixed value,
+  // which makes the step value to be replaced by a constant.
+  if (F.hasFnAttribute(Attribute::VScaleRange))
+    if (auto *ConstStep = dyn_cast<SCEVConstant>(Step)) {
+      ConstantRange CR = llvm::getVScaleRange(&F, 64);
+      if (const APInt *Fixed = CR.getSingleElement()) {
+        // Compare at a width that is at least as wide as both operands.
+        // Truncating a wider step could change its value and make a
+        // non-multiple look divisible.
+        unsigned BW = ConstStep->getAPInt().getBitWidth();
+        if (Fixed->getBitWidth() > BW)
+          BW = Fixed->getBitWidth();
+        APInt StepMag = ConstStep->getAPInt().abs().zextOrTrunc(BW);
+        APInt VScale = Fixed->zextOrTrunc(BW);
+        // Defensive: never divide by zero.
+        if (VScale.isZero())
+          return 0U;
+        uint64_t VF = StepMag.udiv(VScale).getLimitedValue();
+        if (VF && llvm::isUInt<32>(VF) &&
+            // Make sure step is divisible by vscale.
+            StepMag.urem(VScale).isZero())
+          return static_cast<uint32_t>(VF);
+      }
+    }
+
+  return 0U;
+}
+
+// Remove the original induction variable if nothing needs it any more.
+//
+// OrigIndVar must be a two-input PHI whose only user is its own update
+// (RecValue, the value coming back from the latch). In that case the PHI is
+// replaced by its start value and erased.
+//
+// This is only sound if every other user of RecValue is itself dead.
+// Otherwise those users would see `start + step` instead of the running IV
+// value after the rewrite, so in that case we bail out.
+static void tryCleanupOriginalIndVar(PHINode *OrigIndVar,
+                                     const InductionDescriptor &IVD) {
+  if (OrigIndVar->getNumIncomingValues() != 2)
+    return;
+  Value *InitValue = OrigIndVar->getIncomingValue(0);
+  Value *RecValue = OrigIndVar->getIncomingValue(1);
+  if (InitValue != IVD.getStartValue())
+    std::swap(InitValue, RecValue);
+  // Neither incoming value is the start value: the PHI does not have the
+  // expected shape, so leave it alone.
+  if (InitValue != IVD.getStartValue())
+    return;
+
+  // The PHI must only feed the instruction that produces RecValue.
+  if (!OrigIndVar->hasOneUse() || OrigIndVar->user_back() != RecValue)
+    return;
+
+  // Every other user of RecValue must be dead: no uses, no side effects, and
+  // not a terminator. Otherwise rewriting the PHI below would change what a
+  // live user computes.
+  for (User *U : RecValue->users()) {
+    if (U == OrigIndVar)
+      continue;
+    auto *UI = dyn_cast<Instruction>(U);
+    if (!UI || !UI->use_empty() || UI->isTerminator() ||
+        UI->mayHaveSideEffects())
+      return;
+  }
+
+  LLVM_DEBUG(dbgs() << "Removed the original IndVar " << *OrigIndVar << "\n");
+  // Turn OrigIndVar into dead code by replacing all its uses by the initial
+  // value of this loop.
+  OrigIndVar->replaceAllUsesWith(InitValue);
+  OrigIndVar->eraseFromParent();
+}
+
+bool EVLIndVarSimplifyImpl::run(Loop &L) {
+  if (!EnableEVLIndVarSimplify)
+    return false;
+
+  InductionDescriptor IVD;
+  PHINode *IndVar = L.getInductionVariable(SE);
+  if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {
+    LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName()
+                      << "\n");
+    return false;
+  }
+
+  BasicBlock *InitBlock, *BackEdgeBlock;
+  if (!L.getIncomingAndBackEdge(InitBlock, BackEdgeBlock)) {
+    LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in "
+                      << L.getName() << "\n");
+    return false;
+  }
+
+  // Retrieve the loop bounds.
+  std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE);
+  if (!Bounds) {
+    LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName()
+                      << "\n");
+    return false;
+  }
+  Value *CanonicalIVInit = &Bounds->getInitialIVValue();
+  Value *CanonicalIVFinal = &Bounds->getFinalIVValue();
+  const SCEV *CanonicalIVInitV = SE.getSCEV(CanonicalIVInit);
+  const SCEV *CanonicalIVFinalV = SE.getSCEV(CanonicalIVFinal);
+
+  const SCEV *StepV = IVD.getStep();
+  uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());
+  if (!VF) {
+    LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV
+                      << "'\n");
+    return false;
+  }
+  LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName()
+                    << "\n");
+
+  // Try to find the EVL-based induction variable.
+  using namespace PatternMatch;
+  BasicBlock *BB = IndVar->getParent();
+
+  Value *EVLIndVar = nullptr;
+  Value *RemTC = nullptr, *TC = nullptr;
+  auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
+      m_Value(RemTC), m_SpecificInt(VF),
+      /*Scalable=*/m_SpecificInt(1));
+  for (auto &PN : BB->phis()) {
+    if (&PN == IndVar)
+      continue;
+
+    // Check 1: it has to contain both incoming (init) & backedge blocks
+    // from IndVar.
+    if (PN.getBasicBlockIndex(InitBlock) < 0 ||
+        PN.getBasicBlockIndex(BackEdgeBlock) < 0)
+      continue;
+    // Check 2: EVL index is always increasing, thus its inital value has to be
+    // equal to either the initial IV value (when the canonical IV is also
+    // increasing) or the last IV value (when canonical IV is decreasing).
+    Value *Init = PN.getIncomingValueForBlock(InitBlock);
+    using Direction = Loop::LoopBounds::Direction;
+    switch (Bounds->getDirection()) {
+    case Direction::Increasing:
+      if (Init != CanonicalIVInit)
+        continue;
+      break;
+    case Direction::Decreasing:
+      if (Init != CanonicalIVFinal)
+        continue;
+      break;
+    case Direction::Unknown:
+      // To be more permissive and see if either the initial or final IV value
+      // matches PN's init value.
+      if (Init != CanonicalIVInit && Init != CanonicalIVFinal)
+        continue;
+      break;
+    }
+    Value *RecValue = PN.getIncomingValueForBlock(BackEdgeBlock);
+    assert(RecValue);
+
+    LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN
+                      << "\n");
+
+    // Check 3: Pattern match to find the EVL-based index and total trip count
+    // (TC).
+    if (match(RecValue,
+              m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&
+        match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {
+      EVLIndVar = RecValue;
+      break;
+    }
+  }
+
+  if (!EVLIndVar || !TC)
+    return false;
+
+  // Make sure TC is related to the original trip count of the canonical IV.
+  // Specifically, if the canonical trip count is derived from TC.
+  const SCEV *TCV = SE.getSCEV(TC);
+  bool MatchTC = false;
+  if (const auto *ConstTCV = dyn_cast<SCEVConstant>(TCV)) {
+    // If TC is a constant and vscale is also a constant, then the canonical
+    // trip count will be constant. Canonical trip count * Step equals to the
+    // round up of TC.
+    if (const auto *ConstStep = dyn_cast<SCEVConstant>(StepV))
+      if (unsigned CanonicalTC = SE.getSmallConstantTripCount(&L)) {
+        APInt Step = ConstStep->getAPInt().abs().zextOrTrunc(64);
+        APInt CanonicalTripCount(64, CanonicalTC);
+        APInt TripCount = ConstTCV->getAPInt().zextOrTrunc(64);
+        MatchTC = (CanonicalTripCount * Step - TripCount).ult(Step);
+      }
+  }
+  // Otherwise, we simply check if the upper or lower bound expression of the
+  // canonical IV contains TC.
+  auto equalsTC = [&](const SCEV *S) -> bool { return S == TCV; };
+  if (!MatchTC && !llvm::SCEVExprContains(CanonicalIVFinalV, equalsTC) &&
+      !llvm::SCEVExprContains(CanonicalIVInitV, equalsTC))
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");
+
+  // Create an EVL-based comparison and replace the branch to use it as
----------------
preames wrote:

I want to suggest an alternate approach here.  It took me quite a while to try to understand what you're trying to prove here, and I think there's an easier and more general alternative.

First, terminology.  EVLStep = min(VF, sub_nsw(EVL_TC, IV))

Step 1 - Check that the step of the IV being replaced is the same as the EVL IV's step on all but the last iteration.  (You already do this, though the comments don't explain the significance.)

Step 2 - Using SCEV's backedge taken count reasoning, compute the exit value for the EVL IV.

SCEV won't directly do this for you - that's the whole countability bit - but you should be able to compute it as:
IV->evaluateAtIteration(BTC) + EVLStep.evaluateAtIteration(BTC).  (You could simplify the EVLStep based on the final iteration assumption, but you don't need to, and avoiding it reduces your scope for bugs.)

Note that the fact we know EVLStep <= IV.Step is critical for the correctness of the above.  If we didn't, we'd have to prove non-trivial wrap properties.  

Step 3 - Rewrite the exit condition to be an equality comparison between the EVLIV and the new exit value.  Note that you don't need the equality test restriction from the current code.  An important point is that the exit value above is very likely to be EVL_TC in practice, but that proving that is a proof burden you don't actually need.

Use SCEVExpander for this, and let it worry about all the simplifications.  




https://github.com/llvm/llvm-project/pull/91796


More information about the llvm-commits mailing list