[llvm] [AMDGPU] Add register pressure guard on LLVM-IR level to prevent harmful optimizations (PR #171267)
Carl Ritson via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 9 00:55:44 PST 2025
================
@@ -0,0 +1,474 @@
+//===-- AMDGPURegPressureEstimator.cpp - AMDGPU Reg Pressure -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Estimates VGPR register pressure at IR level for AMDGPURegPressureGuard.
+/// Uses RPO dataflow analysis to track live values through the function.
+/// Results are relative only - not comparable to hardware register counts.
+///
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPURegPressureEstimator.h"
+#include "AMDGPU.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-reg-pressure-estimator"
+
+namespace {
+/// Estimates how many VGPR half-registers (16-bit units) \p V occupies.
+///
+/// \param V  Value to cost; a null pointer costs 0.
+/// \param DL Data layout used to size pointers and all other sized types.
+/// \param UA Uniformity info; any value it reports as uniform costs 0
+///           (presumably because it can be kept in SGPRs).
+/// \returns 0 for void, token, metadata, unsized and i1 types, for uniform
+/// values and for compare results. Otherwise the size of \p V in 16-bit
+/// units: integers and ordinary pointers are rounded up to whole 32-bit
+/// registers (two units each), every other type is rounded up to 16 bits.
+static unsigned getVgprCost(Value *V, const DataLayout &DL,
+                            const UniformityInfo &UA) {
+  if (!V)
+    return 0;
+
+  Type *Ty = V->getType();
+  if (Ty->isVoidTy() || Ty->isTokenTy() || Ty->isMetadataTy() ||
+      !Ty->isSized() || Ty->isIntegerTy(1))
+    return 0;
+
+  if (UA.isUniform(V) || isa<CmpInst>(V))
+    return 0;
+
+  // Rounds a bit width up to whole 32-bit registers and converts the result
+  // to half-register units.
+  auto WholeRegisterCost = [](unsigned Bits) {
+    return ((Bits + 31) / 32) * 2;
+  };
+
+  if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+    unsigned AS = PtrTy->getAddressSpace();
+    switch (AS) {
+    case AMDGPUAS::BUFFER_FAT_POINTER:
+      return 2; // offset
+    case AMDGPUAS::BUFFER_RESOURCE:
+      return 0;
+    case AMDGPUAS::BUFFER_STRIDED_POINTER:
+      return 4; // offset + index
+    default:
+      return WholeRegisterCost(DL.getPointerSizeInBits(AS));
+    }
+  }
+
+  unsigned BitWidth = DL.getTypeStoreSizeInBits(Ty).getFixedValue();
+  if (Ty->isIntegerTy())
+    return WholeRegisterCost(BitWidth);
+
+  // Everything else is counted in 16-bit halves, assuming RealTrue16 is on.
+  return (BitWidth + 15) / 16;
+}
+
+// Memoizes block-to-block reachability answers so repeated queries skip the
+// CFG walk; dominance, post-dominance and RPO order serve as early rejections.
+class ReachabilityCache {
+ DenseMap<std::pair<BasicBlock *, BasicBlock *>, bool> BBCache; // (from, to)
+ PostDominatorTree *PDT; // May be null; its shortcut is then skipped.
+
+ DenseMap<BasicBlock *, unsigned> RPOIndex; // Filled by initRPO().
+ bool HasBackEdges = false; // Set by initRPO() on a non-forward edge.
+
+public:
+ DominatorTree &DT; // Used for pruning in isReachable/computeReachability.
+
+ ReachabilityCache(DominatorTree &DT, PostDominatorTree *PDT)
+ : PDT(PDT), DT(DT) {}
+ // Indexes blocks in RPOT iteration order, then looks for a back edge.
+ template <typename RPOTType> void initRPO(RPOTType &RPOT) {
+ unsigned Idx = 0;
+ for (auto *BB : RPOT)
+ RPOIndex[BB] = Idx++;
+ // An edge to a block with an equal or smaller index counts as a back edge.
+ for (auto *BB : RPOT) {
+ unsigned FromIdx = RPOIndex[BB];
+ for (BasicBlock *Succ : successors(BB)) {
+ if (RPOIndex[Succ] <= FromIdx) {
+ HasBackEdges = true;
+ return; // One back edge is enough; stop scanning.
+ }
+ }
+ }
+ }
+ // Returns whether ToInst is considered reachable from FromInst.
+ bool isReachable(Instruction *FromInst, Instruction *ToInst) {
+ BasicBlock *FromBB = FromInst->getParent();
+ BasicBlock *ToBB = ToInst->getParent();
+ // Same block: decided by instruction order alone; this result is not cached.
+ if (FromBB == ToBB)
+ return FromInst->comesBefore(ToInst); // NOTE(review): ignores loops - confirm
+
+ auto Key = std::make_pair(FromBB, ToBB);
+ auto It = BBCache.find(Key);
+ if (It != BBCache.end())
+ return It->second;
+ // Records the answer for this block pair before handing it back.
+ auto CacheAndReturn = [&](bool Result) {
+ BBCache[Key] = Result;
+ return Result;
+ };
+ // A ToBB that dominates FromBB is reported as not reachable.
+ if (DT.dominates(ToBB, FromBB))
+ return CacheAndReturn(false); // NOTE(review): ignores back edges - confirm
+ // Likewise when FromBB post-dominates ToBB (only if a PDT was supplied).
+ if (PDT && PDT->dominates(FromBB, ToBB))
+ return CacheAndReturn(false);
+ // With no back edges, a later RPO index is reported as unable to reach an
+ if (!HasBackEdges && !RPOIndex.empty()) { // earlier one.
+ auto FromIt = RPOIndex.find(FromBB);
+ auto ToIt = RPOIndex.find(ToBB);
+ if (FromIt != RPOIndex.end() && ToIt != RPOIndex.end()) {
+ if (FromIt->second > ToIt->second)
+ return CacheAndReturn(false);
+ }
+ }
+ // No shortcut applied: walk the CFG and cache whatever it finds.
+ return CacheAndReturn(computeReachability(FromBB, ToBB));
+ }
+
+private:
+ bool computeReachability(BasicBlock *FromBB, BasicBlock *ToBB) {
+ SmallPtrSet<BasicBlock *, 32> Visited;
+ SmallVector<BasicBlock *, 16> Worklist;
+ // Worklist search seeded with FromBB's successors; true once ToBB is seen.
+ for (BasicBlock *Succ : successors(FromBB)) {
+ if (Succ == ToBB)
+ return true;
+ Worklist.push_back(Succ);
+ Visited.insert(Succ);
+ }
+
+ Visited.insert(FromBB); // The start block is never expanded again.
+
+ while (!Worklist.empty()) {
+ BasicBlock *BB = Worklist.pop_back_val(); // LIFO: most recent block first.
+
+ for (BasicBlock *Succ : successors(BB)) {
+ if (Succ == ToBB)
+ return true;
+
+ if (Visited.count(Succ))
+ continue;
+ // Successors that dominate FromBB are not explored.
+ if (DT.dominates(Succ, FromBB))
+ continue;
+
+ Visited.insert(Succ);
+ Worklist.push_back(Succ);
+ }
+ }
+
+ return false; // Worklist exhausted without encountering ToBB.
+ }
+};
+
+static bool isValueDead(Value *V, Instruction *I, ReachabilityCache &Cache) {
----------------
perlfu wrote:
This needs a comment explaining the relationship between `V` and `I`, or better parameter names.
https://github.com/llvm/llvm-project/pull/171267
More information about the llvm-commits
mailing list