[llvm] [AMDGPU] Add register pressure guard on LLVM-IR level to prevent harmful optimizations (PR #171267)

Carl Ritson via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 9 00:55:44 PST 2025


================
@@ -0,0 +1,474 @@
+//===-- AMDGPURegPressureEstimator.cpp - AMDGPU Reg Pressure -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Estimates VGPR register pressure at IR level for AMDGPURegPressureGuard.
+/// Uses RPO dataflow analysis to track live values through the function.
+/// Results are relative only - not comparable to hardware register counts.
+///
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPURegPressureEstimator.h"
+#include "AMDGPU.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-reg-pressure-estimator"
+
+namespace {
+// Returns the estimated VGPR cost of \p V in half-registers (16-bit units).
+// Returns 0 for: null values; void, token, metadata, unsized and scalar i1
+// types; uniform values (assumed to live in SGPRs, which also covers
+// constants); compare results; and buffer resource pointers.
+static unsigned getVgprCost(Value *V, const DataLayout &DL,
+                            const UniformityInfo &UA) {
+  if (!V)
+    return 0;
+
+  Type *Ty = V->getType();
+  if (Ty->isVoidTy() || Ty->isTokenTy() || Ty->isMetadataTy() ||
+      !Ty->isSized() || Ty->isIntegerTy(1))
+    return 0;
+
+  if (UA.isUniform(V) || isa<CmpInst>(V))
+    return 0;
+
+  // Cost of a value that occupies whole 32-bit VGPRs, in half-register units.
+  auto WholeVgprCost = [](uint64_t Bits) -> unsigned {
+    return static_cast<unsigned>(((Bits + 31) / 32) * 2);
+  };
+
+  if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+    unsigned AS = PtrTy->getAddressSpace();
+    switch (AS) {
+    case AMDGPUAS::BUFFER_FAT_POINTER:
+      return 2; // Only the 32-bit offset is counted.
+    case AMDGPUAS::BUFFER_RESOURCE:
+      return 0; // Buffer resources are not counted.
+    case AMDGPUAS::BUFFER_STRIDED_POINTER:
+      return 4; // 32-bit offset + 32-bit index.
+    default:
+      return WholeVgprCost(DL.getPointerSizeInBits(AS));
+    }
+  }
+
+  uint64_t BitWidth = DL.getTypeStoreSizeInBits(Ty).getFixedValue();
+  if (Ty->isIntegerTy())
+    return WholeVgprCost(BitWidth);
+
+  // Non-integer types (FP, vectors, aggregates) are costed at 16-bit
+  // granularity, assuming RealTrue16 is enabled.
+  return static_cast<unsigned>((BitWidth + 15) / 16);
+}
+
+// Caches block-to-block reachability queries to avoid redundant CFG walks.
+//
+// Answers are plain CFG reachability. Inside a cycle, a block can reach
+// blocks that dominate it or that precede it in RPO (e.g. a loop latch
+// reaches its header). The dominator, post-dominator and RPO-index shortcuts
+// that reject such "backward" queries are therefore only applied once
+// initRPO() has shown the reachable CFG to be acyclic.
+class ReachabilityCache {
+  DenseMap<std::pair<BasicBlock *, BasicBlock *>, bool> BBCache;
+  PostDominatorTree *PDT;
+
+  // Position of each block in the traversal passed to initRPO().
+  DenseMap<BasicBlock *, unsigned> RPOIndex;
+  // Set by initRPO() if some edge targets a block that is not later in RPO.
+  bool HasBackEdges = false;
+
+public:
+  DominatorTree &DT;
+
+  ReachabilityCache(DominatorTree &DT, PostDominatorTree *PDT)
+      : PDT(PDT), DT(DT) {}
+
+  /// Records the RPO position of every block in \p RPOT and detects whether
+  /// any of their outgoing edges is a back edge: a self-loop or an edge to a
+  /// block at an earlier position.
+  template <typename RPOTType> void initRPO(RPOTType &RPOT) {
+    unsigned Idx = 0;
+    for (auto *BB : RPOT)
+      RPOIndex[BB] = Idx++;
+
+    for (auto *BB : RPOT) {
+      unsigned FromIdx = RPOIndex.lookup(BB);
+      for (BasicBlock *Succ : successors(BB)) {
+        // lookup() rather than operator[] so a query never inserts; a
+        // successor missing from RPOT reads as 0 and conservatively counts
+        // as a back edge.
+        if (RPOIndex.lookup(Succ) <= FromIdx) {
+          HasBackEdges = true;
+          return;
+        }
+      }
+    }
+  }
+
+  /// Returns true if \p ToInst can execute after \p FromInst along some CFG
+  /// path. Within one block this holds when \p FromInst comes first, or when
+  /// control can leave the block and come back to it. In particular, an
+  /// instruction only reaches itself through a cycle.
+  bool isReachable(Instruction *FromInst, Instruction *ToInst) {
+    BasicBlock *FromBB = FromInst->getParent();
+    BasicBlock *ToBB = ToInst->getParent();
+
+    if (FromBB == ToBB && FromInst->comesBefore(ToInst))
+      return true;
+
+    // Block-level question: can control leave FromBB and later enter ToBB?
+    // For FromBB == ToBB this asks whether FromBB lies on a cycle.
+    auto Key = std::make_pair(FromBB, ToBB);
+    auto It = BBCache.find(Key);
+    if (It != BBCache.end())
+      return It->second;
+
+    bool Result = !isKnownUnreachable(FromBB, ToBB) &&
+                  computeReachability(FromBB, ToBB);
+    BBCache[Key] = Result;
+    return Result;
+  }
+
+private:
+  // Cheap proofs that ToBB cannot be entered after leaving FromBB. In an
+  // acyclic CFG a block cannot reach itself, a block that dominates it, a
+  // block it post-dominates, or a block earlier in RPO. None of these holds
+  // once cycles exist, so nothing is proved unless initRPO() ran, found no
+  // back edges, and saw both blocks.
+  bool isKnownUnreachable(BasicBlock *FromBB, BasicBlock *ToBB) const {
+    if (HasBackEdges)
+      return false;
+    auto FromIt = RPOIndex.find(FromBB);
+    auto ToIt = RPOIndex.find(ToBB);
+    if (FromIt == RPOIndex.end() || ToIt == RPOIndex.end())
+      return false;
+
+    // Also covers FromBB == ToBB, since a block dominates itself.
+    if (DT.dominates(ToBB, FromBB))
+      return true;
+    if (PDT && PDT->dominates(FromBB, ToBB))
+      return true;
+    return FromIt->second > ToIt->second;
+  }
+
+  // Walks the CFG starting from FromBB's successors, so FromBB == ToBB asks
+  // whether FromBB lies on a cycle.
+  bool computeReachability(BasicBlock *FromBB, BasicBlock *ToBB) {
+    // A visited block that dominates an entry-reachable ToBB also reaches
+    // it. The reachability guard is needed because every block "dominates"
+    // an unreachable block.
+    bool ToReachable = DT.isReachableFromEntry(ToBB);
+
+    SmallPtrSet<BasicBlock *, 32> Visited;
+    SmallVector<BasicBlock *, 16> Worklist;
+    for (BasicBlock *Succ : successors(FromBB))
+      Worklist.push_back(Succ);
+
+    while (!Worklist.empty()) {
+      BasicBlock *BB = Worklist.pop_back_val();
+      if (BB == ToBB || (ToReachable && DT.dominates(BB, ToBB)))
+        return true;
+      if (!Visited.insert(BB).second)
+        continue;
+      for (BasicBlock *Succ : successors(BB))
+        Worklist.push_back(Succ);
+    }
+
+    return false;
+  }
+};
+
+static bool isValueDead(Value *V, Instruction *I, ReachabilityCache &Cache) {
----------------
perlfu wrote:

Needs a comment to explain the relationship between `V` and `I`, or better parameter naming.

https://github.com/llvm/llvm-project/pull/171267


More information about the llvm-commits mailing list