[llvm] [AMDGPU] Add register pressure guard on LLVM-IR level to prevent harmful optimizations (PR #171267)
Carl Ritson via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 9 00:55:44 PST 2025
================
@@ -0,0 +1,474 @@
+//===-- AMDGPURegPressureEstimator.cpp - AMDGPU Reg Pressure -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Estimates VGPR register pressure at IR level for AMDGPURegPressureGuard.
+/// Uses RPO dataflow analysis to track live values through the function.
+/// Results are relative only - not comparable to hardware register counts.
+///
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPURegPressureEstimator.h"
+#include "AMDGPU.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-reg-pressure-estimator"
+
+namespace {
+// Returns the VGPR cost of V in half-registers (16-bit units).
+// Returns 0 for values that are not counted as VGPRs: null, void, token,
+// metadata and unsized values, i1 values, uniform values (constants are
+// uniform), and compare results.
+static unsigned getVgprCost(Value *V, const DataLayout &DL,
+                            const UniformityInfo &UA) {
+  if (!V)
+    return 0;
+
+  Type *Ty = V->getType();
+  if (Ty->isVoidTy() || Ty->isTokenTy() || Ty->isMetadataTy() ||
+      !Ty->isSized() || Ty->isIntegerTy(1))
+    return 0;
+
+  // NOTE(review): compare results are never counted, even vector compares;
+  // presumably because they are expected to become lane masks - confirm.
+  if (UA.isUniform(V) || isa<CmpInst>(V))
+    return 0;
+
+  if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+    unsigned AS = PtrTy->getAddressSpace();
+    switch (AS) {
+    case AMDGPUAS::BUFFER_FAT_POINTER:
+      return 2; // offset only; the resource part is not counted
+    case AMDGPUAS::BUFFER_RESOURCE:
+      return 0;
+    case AMDGPUAS::BUFFER_STRIDED_POINTER:
+      return 4; // offset + index
+    default: {
+      // Other pointers occupy whole 32-bit registers.
+      unsigned PtrBits = DL.getPointerSizeInBits(AS);
+      return ((PtrBits + 31) / 32) * 2;
+    }
+    }
+  }
+
+  // Scalar integers are rounded up to whole 32-bit registers.
+  unsigned BitWidth = DL.getTypeStoreSizeInBits(Ty).getFixedValue();
+  if (Ty->isIntegerTy())
+    return ((BitWidth + 31) / 32) * 2;
+
+  // Everything else is counted at 16-bit granularity (assumes RealTrue16).
+  return (BitWidth + 15) / 16;
+}
+
+// Answers "can ToInst execute after FromInst?" and memoizes the answer per
+// (source block, destination block) pair. When the CFG has no back edges,
+// RPO positions give a quick negative answer if the destination comes first.
+//
+// NOTE(review): the dominator / post-dominator early-outs in isReachable, and
+// the dominance pruning inside computeReachability, do not follow paths that
+// re-enter a block dominating the source. Same-block queries are answered by
+// instruction order alone. Inside loops this can therefore report blocks as
+// unreachable even though they are reachable around a back edge.
+class ReachabilityCache {
+  DenseMap<std::pair<BasicBlock *, BasicBlock *>, bool> BBCache;
+  PostDominatorTree *PDT;
+
+  DenseMap<BasicBlock *, unsigned> RPOIndex;
+  bool HasBackEdges = false;
+
+public:
+  DominatorTree &DT;
+
+  ReachabilityCache(DominatorTree &DT, PostDominatorTree *PDT)
+      : PDT(PDT), DT(DT) {}
+
+  /// Numbers the blocks of \p RPOT in traversal order. Also records whether
+  /// some edge fails to move forward in that numbering (a back edge).
+  template <typename RPOTType> void initRPO(RPOTType &RPOT) {
+    unsigned Position = 0;
+    for (auto *Block : RPOT)
+      RPOIndex[Block] = Position++;
+
+    for (auto *Block : RPOT) {
+      const unsigned SrcPos = RPOIndex[Block];
+      for (BasicBlock *Dst : successors(Block)) {
+        if (RPOIndex[Dst] > SrcPos)
+          continue;
+        HasBackEdges = true;
+        return;
+      }
+    }
+  }
+
+  /// Returns true if \p ToInst is considered reachable from \p FromInst,
+  /// under the rules described on the class.
+  bool isReachable(Instruction *FromInst, Instruction *ToInst) {
+    BasicBlock *Src = FromInst->getParent();
+    BasicBlock *Dst = ToInst->getParent();
+    if (Src == Dst)
+      return FromInst->comesBefore(ToInst);
+
+    const auto Key = std::make_pair(Src, Dst);
+    if (auto Hit = BBCache.find(Key); Hit != BBCache.end())
+      return Hit->second;
+
+    const bool Result =
+        !isTriviallyUnreachable(Src, Dst) && computeReachability(Src, Dst);
+    BBCache[Key] = Result;
+    return Result;
+  }
+
+private:
+  // Cheap negative answers tried before falling back to the graph search.
+  bool isTriviallyUnreachable(BasicBlock *Src, BasicBlock *Dst) {
+    if (DT.dominates(Dst, Src))
+      return true;
+    if (PDT && PDT->dominates(Src, Dst))
+      return true;
+    if (HasBackEdges || RPOIndex.empty())
+      return false;
+
+    auto SrcIt = RPOIndex.find(Src);
+    auto DstIt = RPOIndex.find(Dst);
+    return SrcIt != RPOIndex.end() && DstIt != RPOIndex.end() &&
+           SrcIt->second > DstIt->second;
+  }
+
+  // Depth-first search from FromBB for an edge into ToBB. FromBB's direct
+  // successors are always explored. Deeper blocks that dominate FromBB are
+  // never expanded.
+  bool computeReachability(BasicBlock *FromBB, BasicBlock *ToBB) {
+    SmallPtrSet<BasicBlock *, 32> Seen;
+    SmallVector<BasicBlock *, 16> Stack;
+
+    for (BasicBlock *Next : successors(FromBB)) {
+      if (Next == ToBB)
+        return true;
+      Seen.insert(Next);
+      Stack.push_back(Next);
+    }
+    Seen.insert(FromBB);
+
+    while (!Stack.empty()) {
+      BasicBlock *Cur = Stack.pop_back_val();
+      for (BasicBlock *Next : successors(Cur)) {
+        if (Next == ToBB)
+          return true;
+        if (Seen.count(Next) || DT.dominates(Next, FromBB))
+          continue;
+        Seen.insert(Next);
+        Stack.push_back(Next);
+      }
+    }
+    return false;
+  }
+};
+
+// Returns true if V has no use left once I has executed. Every instruction
+// user of V must be I itself, dominate I, or be unreachable from I according
+// to Cache. Non-instruction users are ignored.
+//
+// NOTE(review): a use that dominates I may still run again on a later
+// iteration of an enclosing loop; that liveness is not modelled here.
+static bool isValueDead(Value *V, Instruction *I, ReachabilityCache &Cache) {
+  for (User *U : V->users()) {
+    auto *UseInst = dyn_cast<Instruction>(U);
+    // Uses that are I itself, or that dominate I, count as already executed.
+    const bool MayRunLater =
+        UseInst && UseInst != I && !Cache.DT.dominates(UseInst, I);
+    if (MayRunLater && Cache.isReachable(I, UseInst))
+      return false;
+  }
+  return true;
+}
+
+// Estimates VGPR register pressure using forward dataflow analysis in RPO.
+// Tracks live value ranges to compute pressure at each program point.
+//
+// Costs are in half-registers (see getVgprCost). A value becomes live at its
+// definition. It leaves the live map at the operand use after which
+// isValueDead finds no further reachable use. A block's entry state is the
+// union of the exit states of the predecessors already visited in RPO.
+//
+// NOTE(review): predecessors that come later in RPO (back edges) contribute
+// nothing. Values that are only live around a loop back edge are therefore
+// not carried into the loop header.
+class AMDGPURegPressureEstimator {
+private:
+  using LiveMap = DenseMap<Value *, unsigned>;
+
+  Function &F;
+  const UniformityInfo &UA;
+  const DataLayout &DL;
+
+  // Values removed from the live map at a last use somewhere in F. Such a
+  // value is only re-admitted at a block entry if that block can still reach
+  // one of its uses.
+  DenseSet<Value *> GlobalDeadSet;
+  ReachabilityCache ReachCache;
+  unsigned MaxPressureHalfRegs = 0;
+
+public:
+  AMDGPURegPressureEstimator(Function &F, DominatorTree &DT,
+                             PostDominatorTree *PDT, const UniformityInfo &UA)
+      : F(F), UA(UA), DL(F.getParent()->getDataLayout()),
+        ReachCache(DT, PDT) {}
+
+  /// Peak estimated pressure in whole VGPRs (half-registers, rounded up).
+  unsigned getMaxVGPRs() const { return (MaxPressureHalfRegs + 1) / 2; }
+
+  /// Walks F in reverse post-order and records the peak pressure, which is
+  /// then available through getMaxVGPRs().
+  void analyze() {
+    LLVM_DEBUG(dbgs() << "Analyzing function: " << F.getName() << "\n");
+
+    DenseMap<BasicBlock *, LiveMap> BlockExitStates;
+
+    ReversePostOrderTraversal<Function *> RPOT(&F);
+    ReachCache.initRPO(RPOT);
+
+    BasicBlock *EntryBB = &F.getEntryBlock();
+    for (BasicBlock *BB : RPOT) {
+      LiveMap EntryState = BB == EntryBB
+                               ? computeArgumentLiveMap()
+                               : mergePredecessorStates(*BB, BlockExitStates);
+      LiveMap ExitState = processBlock(*BB, std::move(EntryState));
+      BlockExitStates[BB] = std::move(ExitState);
+    }
+
+    // Report the same rounded-up figure that getMaxVGPRs() returns.
+    LLVM_DEBUG(dbgs() << " Max pressure: " << getMaxVGPRs() << " VGPRs\n");
+  }
+
+private:
+  // Live-in state of the entry block: the used arguments that need VGPRs.
+  LiveMap computeArgumentLiveMap() const {
+    LiveMap Live;
+    for (Argument &Arg : F.args()) {
+      if (Arg.use_empty())
+        continue;
+      if (unsigned Cost = getVgprCost(&Arg, DL, UA))
+        Live[&Arg] = Cost;
+    }
+    return Live;
+  }
+
+  // Entry state of a non-entry block: the union of the exit states of its
+  // already-processed predecessors. If predecessors disagree on a value's
+  // cost (extracts can shrink it), the larger cost is kept. The result then
+  // does not depend on predecessor order.
+  LiveMap
+  mergePredecessorStates(BasicBlock &BB,
+                         const DenseMap<BasicBlock *, LiveMap> &ExitStates) {
+    LiveMap Merged;
+    for (BasicBlock *Pred : predecessors(&BB)) {
+      auto PredIt = ExitStates.find(Pred);
+      // Predecessors later in RPO, or unreachable ones, have no state yet.
+      if (PredIt == ExitStates.end())
+        continue;
+
+      for (const auto &[V, Cost] : PredIt->second) {
+        // A value killed on one path must not be resurrected at a join.
+        // However, it must stay live on a sibling path that still uses it,
+        // e.g. the other arm of an if/else.
+        if (GlobalDeadSet.count(V) && !hasUseReachableFrom(V, BB))
+          continue;
+        unsigned &Slot = Merged[V];
+        if (Cost > Slot)
+          Slot = Cost;
+      }
+    }
+    return Merged;
+  }
+
+  // Returns true if V has an instruction user inside BB, or in a block that
+  // ReachCache considers reachable from BB's first instruction.
+  bool hasUseReachableFrom(Value *V, BasicBlock &BB) {
+    Instruction *First = &BB.front();
+    for (User *U : V->users()) {
+      auto *UseInst = dyn_cast<Instruction>(U);
+      if (!UseInst)
+        continue;
+      if (UseInst->getParent() == &BB ||
+          ReachCache.isReachable(First, UseInst))
+        return true;
+    }
+    return false;
+  }
+
+  // Walks BB starting from the given entry state. Updates liveness at every
+  // definition and last use, records the peak, and returns BB's exit state.
+  LiveMap processBlock(BasicBlock &BB, LiveMap Live) {
+    unsigned Pressure = computePressure(Live);
+    recordPressure(Pressure);
+
+    for (Instruction &I : BB) {
+      if (I.isDebugOrPseudoInst())
+        continue;
+
+      if (!I.getType()->isVoidTy() && !I.use_empty())
+        defineValue(I, Live, Pressure);
+      killDeadOperands(I, Live, Pressure);
+
+      // Sampled after dead operands are released. A dying operand and the
+      // newly defined result are therefore never counted together.
+      recordPressure(Pressure);
+    }
+    return Live;
+  }
+
+  // Makes the value defined by I live when it needs VGPRs. getVgprCost
+  // decides whether I occupies VGPRs at all. This includes divergent values
+  // without a live VGPR operand, such as workitem-id intrinsics or call
+  // results. Inserts and extracts refine the cost using their live operands.
+  void defineValue(Instruction &I, LiveMap &Live, unsigned &Pressure) {
+    if (isa<InsertElementInst>(&I) || isa<InsertValueInst>(&I)) {
+      Value *Aggregate = nullptr;
+      Value *Inserted = nullptr;
+      if (auto *IEI = dyn_cast<InsertElementInst>(&I)) {
+        Aggregate = IEI->getOperand(0);
+        Inserted = IEI->getOperand(1);
+      } else if (auto *IVI = dyn_cast<InsertValueInst>(&I)) {
+        Aggregate = IVI->getAggregateOperand();
+        Inserted = IVI->getInsertedValueOperand();
+      }
+      // An insert costs what its live inputs cost. If neither input is live
+      // (e.g. uniform inputs at a divergent index), fall back to the result's
+      // own cost.
+      unsigned Cost = Live.lookup(Aggregate) + Live.lookup(Inserted);
+      if (Cost == 0)
+        Cost = getVgprCost(&I, DL, UA);
+      makeLive(I, Cost, Live, Pressure);
+      return;
+    }
+
+    if (isa<ExtractValueInst>(&I) || isa<ExtractElementInst>(&I)) {
+      Value *Source = nullptr;
+      if (auto *EVI = dyn_cast<ExtractValueInst>(&I))
+        Source = EVI->getAggregateOperand();
+      else if (auto *EEI = dyn_cast<ExtractElementInst>(&I))
+        Source = EEI->getVectorOperand();
+
+      unsigned Cost = getVgprCost(&I, DL, UA);
+      if (Cost == 0)
+        return;
+      makeLive(I, Cost, Live, Pressure);
+
+      // Treat the extracted part as moving out of a live source: shrink the
+      // source's cost by the same amount, and drop it once nothing is left.
+      auto SourceIt = Live.find(Source);
+      if (SourceIt != Live.end() && SourceIt->second >= Cost) {
+        SourceIt->second -= Cost;
+        Pressure -= Cost;
+        if (SourceIt->second == 0)
+          Live.erase(SourceIt);
+      }
+      return;
+    }
+
+    makeLive(I, getVgprCost(&I, DL, UA), Live, Pressure);
+  }
+
+  // Adds I to the live map with the given cost. A zero cost means I needs no
+  // VGPRs.
+  static void makeLive(Instruction &I, unsigned Cost, LiveMap &Live,
+                       unsigned &Pressure) {
+    if (Cost == 0)
+      return;
+    Live[&I] = Cost;
+    Pressure += Cost;
+  }
+
+  // Releases each live operand of I that isValueDead reports as unused after
+  // I, and records it in GlobalDeadSet.
+  void killDeadOperands(Instruction &I, LiveMap &Live, unsigned &Pressure) {
+    for (Value *V : I.operand_values()) {
+      if (isa<Constant>(V) || isa<BasicBlock>(V))
+        continue;
+      auto It = Live.find(V);
+      if (It == Live.end() || !isValueDead(V, &I, ReachCache))
+        continue;
+      Pressure -= It->second;
+      Live.erase(It);
+      GlobalDeadSet.insert(V);
+    }
+  }
+
+  void recordPressure(unsigned Pressure) {
+    if (Pressure > MaxPressureHalfRegs)
+      MaxPressureHalfRegs = Pressure;
+  }
+
+  // Sum of the half-register costs in the live map.
+  static unsigned computePressure(const LiveMap &Live) {
+    unsigned Total = 0;
+    for (const auto &Entry : Live)
+      Total += Entry.second;
+    return Total;
+  }
+};
+
+} // end anonymous namespace
+
+namespace llvm {
+
+unsigned computeMaxVGPRPressure(Function &F, DominatorTree &DT,
+ PostDominatorTree *PDT,
+ const UniformityInfo &UA) {
+ ::AMDGPURegPressureEstimator Estimator(F, DT, PDT, UA);
----------------
perlfu wrote:
Should be able to do without the `::` here.
https://github.com/llvm/llvm-project/pull/171267
More information about the llvm-commits
mailing list