[llvm] [AMDGPU] Add register pressure guard on LLVM-IR level to prevent harmful optimizations (PR #171267)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 9 02:24:07 PST 2025


https://github.com/tianhbai updated https://github.com/llvm/llvm-project/pull/171267

>From 94ddd976f01d6fd8bcb3428a355c7608e0974f87 Mon Sep 17 00:00:00 2001
From: tianhbai <tianhbai at gmail.com>
Date: Tue, 9 Dec 2025 11:17:21 +0800
Subject: [PATCH 1/2] [AMDGPU] Add register pressure guard to prevent harmful
 optimizations

This patch introduces a register pressure-aware guard for the AMDGPU backend
that prevents optimization passes from making transformations that increase
register pressure beyond architecture-specific thresholds, which can lead to
register spilling and performance degradation.

The guard consists of two main components:
1. AMDGPURegPressureEstimator: Estimates register pressure for AMDGPU
2. AMDGPURegPressureGuard: Analyzes transformations and reverts harmful ones

The implementation tracks register pressure before and after transformations,
and reverts changes that exceed configurable thresholds. This optimization
guard is applied to passes like LICM and Sinking in the AMDGPU backend.

Note: The estimator provides conservative estimates intended for comparing
register pressure before and after optimizations, not for precise allocation
decisions. It may overestimate pressure and does not fully account for backend
optimizations like CSE of duplicate extracts or shuffle operations.
---
 llvm/lib/Target/AMDGPU/AMDGPU.h               |  12 +
 .../AMDGPU/AMDGPURegPressureEstimator.cpp     | 474 ++++++++++++++++++
 .../AMDGPU/AMDGPURegPressureEstimator.h       |  92 ++++
 .../Target/AMDGPU/AMDGPURegPressureGuard.cpp  | 301 +++++++++++
 .../Target/AMDGPU/AMDGPURegPressureGuard.h    | 124 +++++
 .../lib/Target/AMDGPU/AMDGPUTargetMachine.cpp |  15 +-
 llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h  |   4 +
 llvm/lib/Target/AMDGPU/CMakeLists.txt         |   2 +
 8 files changed, 1022 insertions(+), 2 deletions(-)
 create mode 100644 llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp
 create mode 100644 llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.h
 create mode 100644 llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.cpp
 create mode 100644 llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.h

diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.h b/llvm/lib/Target/AMDGPU/AMDGPU.h
index 5af2a2755cec3..2cc749e7fcb69 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.h
@@ -173,6 +173,18 @@ extern char &AMDGPUReserveWWMRegsLegacyID;
 void initializeAMDGPURewriteOutArgumentsPass(PassRegistry &);
 extern char &AMDGPURewriteOutArgumentsID;
 
+void initializeAMDGPURegPressureEstimatorWrapperPassPass(PassRegistry &);
+extern char &AMDGPURegPressureEstimatorWrapperPassID;
+
+void initializeRegPressureBaselineMeasurementPassPass(PassRegistry &);
+extern char &RegPressureBaselineMeasurementPassID;
+
+void initializeRegPressureVerificationPassPass(PassRegistry &);
+extern char &RegPressureVerificationPassID;
+
+void initializeAMDGPURegPressureGuardLegacyPassPass(PassRegistry &);
+extern char &AMDGPURegPressureGuardLegacyPassID;
+
 void initializeGCNDPPCombineLegacyPass(PassRegistry &);
 extern char &GCNDPPCombineLegacyID;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp
new file mode 100644
index 0000000000000..b9114673cfc53
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp
@@ -0,0 +1,474 @@
+//===-- AMDGPURegPressureEstimator.cpp - AMDGPU Reg Pressure -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Estimates VGPR register pressure at IR level for AMDGPURegPressureGuard.
+/// Uses RPO dataflow analysis to track live values through the function.
+/// Results are relative only - not comparable to hardware register counts.
+///
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPURegPressureEstimator.h"
+#include "AMDGPU.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-reg-pressure-estimator"
+
+namespace {
+// Returns the VGPR cost of V in half-registers (16-bit units). Returns 0 for
+// void/token/metadata/unsized/i1 types, uniform values and compare results.
+static unsigned getVgprCost(Value *V, const DataLayout &DL,
+                            const UniformityInfo &UA) {
+  if (!V)
+    return 0;
+
+  Type *Ty = V->getType();
+  if (Ty->isVoidTy() || Ty->isTokenTy() || Ty->isMetadataTy() ||
+      !Ty->isSized() || Ty->isIntegerTy(1))
+    return 0;
+
+  if (UA.isUniform(V) || isa<CmpInst>(V))
+    return 0;
+
+  if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+    unsigned AS = PtrTy->getAddressSpace();
+    switch (AS) {
+    case AMDGPUAS::BUFFER_FAT_POINTER:
+      return 2; // offset
+    case AMDGPUAS::BUFFER_RESOURCE:
+      return 0; // buffer resource: counted as no VGPRs
+    case AMDGPUAS::BUFFER_STRUCTURED_POINTER:
+      return 4; // offset + index
+    default:
+      unsigned BitWidth = DL.getPointerSizeInBits(AS);
+      return ((BitWidth + 31) / 32) * 2; // whole 32-bit registers
+    }
+  }
+
+  unsigned BitWidth = DL.getTypeStoreSizeInBits(Ty).getFixedValue();
+  if (Ty->isIntegerTy())
+    return ((BitWidth + 31) / 32) * 2; // integers round up to whole registers
+
+  // Assuming RealTrue16 on: everything else is counted in 16-bit halves.
+  return (BitWidth + 15) / 16;
+}
+
+// Caches block-to-block reachability queries to avoid redundant CFG searches.
+// Uses RPO indices to reject backward queries when no back edge was found.
+class ReachabilityCache {
+  DenseMap<std::pair<BasicBlock *, BasicBlock *>, bool> BBCache;
+  PostDominatorTree *PDT;
+
+  DenseMap<BasicBlock *, unsigned> RPOIndex;
+  bool HasBackEdges = false;
+
+public:
+  DominatorTree &DT;
+
+  ReachabilityCache(DominatorTree &DT, PostDominatorTree *PDT)
+      : PDT(PDT), DT(DT) {}
+
+  template <typename RPOTType> void initRPO(RPOTType &RPOT) {
+    unsigned Idx = 0; // next reverse-post-order position to hand out
+    for (auto *BB : RPOT)
+      RPOIndex[BB] = Idx++;
+    // Any edge to a block at an equal or earlier position sets HasBackEdges.
+    for (auto *BB : RPOT) {
+      unsigned FromIdx = RPOIndex[BB];
+      for (BasicBlock *Succ : successors(BB)) {
+        if (RPOIndex[Succ] <= FromIdx) {
+          HasBackEdges = true;
+          return;
+        }
+      }
+    }
+  }
+
+  bool isReachable(Instruction *FromInst, Instruction *ToInst) {
+    BasicBlock *FromBB = FromInst->getParent();
+    BasicBlock *ToBB = ToInst->getParent();
+    // Within one block the answer depends only on instruction order.
+    if (FromBB == ToBB)
+      return FromInst->comesBefore(ToInst);
+
+    auto Key = std::make_pair(FromBB, ToBB);
+    auto It = BBCache.find(Key);
+    if (It != BBCache.end())
+      return It->second;
+
+    auto CacheAndReturn = [&](bool Result) {
+      BBCache[Key] = Result;
+      return Result;
+    };
+    // NOTE(review): these dominance shortcuts seem to ignore back-edge paths.
+    if (DT.dominates(ToBB, FromBB))
+      return CacheAndReturn(false);
+
+    if (PDT && PDT->dominates(FromBB, ToBB))
+      return CacheAndReturn(false);
+    // Without back edges a later RPO position cannot reach an earlier one.
+    if (!HasBackEdges && !RPOIndex.empty()) {
+      auto FromIt = RPOIndex.find(FromBB);
+      auto ToIt = RPOIndex.find(ToBB);
+      if (FromIt != RPOIndex.end() && ToIt != RPOIndex.end()) {
+        if (FromIt->second > ToIt->second)
+          return CacheAndReturn(false);
+      }
+    }
+    // Fall back to an explicit CFG search and remember the result.
+    return CacheAndReturn(computeReachability(FromBB, ToBB));
+  }
+
+private:
+  bool computeReachability(BasicBlock *FromBB, BasicBlock *ToBB) {
+    SmallPtrSet<BasicBlock *, 32> Visited;
+    SmallVector<BasicBlock *, 16> Worklist;
+    // Seed the search with the direct successors of FromBB.
+    for (BasicBlock *Succ : successors(FromBB)) {
+      if (Succ == ToBB)
+        return true;
+      Worklist.push_back(Succ);
+      Visited.insert(Succ);
+    }
+
+    Visited.insert(FromBB);
+
+    while (!Worklist.empty()) {
+      BasicBlock *BB = Worklist.pop_back_val();
+
+      for (BasicBlock *Succ : successors(BB)) {
+        if (Succ == ToBB)
+          return true;
+
+        if (Visited.count(Succ))
+          continue;
+        // Blocks that dominate FromBB are never explored further.
+        if (DT.dominates(Succ, FromBB))
+          continue;
+
+        Visited.insert(Succ);
+        Worklist.push_back(Succ);
+      }
+    }
+
+    return false;
+  }
+};
+
+// Returns true when no instruction using V is considered reachable from I,
+// i.e. the live range of V can be treated as ending at I.
+static bool isValueDead(Value *V, Instruction *I, ReachabilityCache &Cache) {
+  for (User *U : V->users()) {
+    auto *UseInst = dyn_cast<Instruction>(U);
+
+    // Non-instruction users, I itself and uses dominating I never keep V
+    // alive past I.
+    if (!UseInst || UseInst == I || Cache.DT.dominates(UseInst, I))
+      continue;
+
+    // Any other use that is still reachable from I extends the live range.
+    if (Cache.isReachable(I, UseInst))
+      return false;
+  }
+
+  return true;
+}
+
+// Estimates VGPR register pressure using forward dataflow analysis in RPO.
+// Tracks live value ranges to compute pressure at each program point.
+class AMDGPURegPressureEstimator {
+private:
+  Function &F;
+  DominatorTree &DT;
+  PostDominatorTree *PDT;
+  const UniformityInfo &UA;
+  const DataLayout &DL;
+
+  DenseSet<Value *> GlobalDeadSet;
+  ReachabilityCache ReachCache;
+  unsigned MaxPressureHalfRegs = 0;
+
+public:
+  AMDGPURegPressureEstimator(Function &F, DominatorTree &DT,
+                             PostDominatorTree *PDT, const UniformityInfo &UA)
+      : F(F), DT(DT), PDT(PDT), UA(UA), DL(F.getParent()->getDataLayout()),
+        ReachCache(DT, PDT) {}
+
+  unsigned getMaxVGPRs() const { return (MaxPressureHalfRegs + 1) / 2; }
+
+  // Visits the blocks in reverse post-order. Each block starts from the values
+  // live out of its already-visited predecessors (the entry block from the
+  // used, VGPR-costing arguments) and is then simulated by processBlock().
+  void analyze() {
+    LLVM_DEBUG(dbgs() << "Analyzing function: " << F.getName() << "\n");
+
+    DenseMap<BasicBlock *, DenseMap<Value *, unsigned>> BlockExitStates;
+
+    ReversePostOrderTraversal<Function *> RPOT(&F);
+    ReachCache.initRPO(RPOT);
+
+    BasicBlock *EntryBB = &F.getEntryBlock();
+    DenseMap<Value *, unsigned> EntryLiveMap;
+    for (Argument &Arg : F.args()) {
+      if (Arg.use_empty())
+        continue;
+      unsigned Cost = getVgprCost(&Arg, DL, UA);
+      if (Cost > 0)
+        EntryLiveMap[&Arg] = Cost;
+    }
+
+    for (BasicBlock *BB : RPOT) {
+      DenseMap<Value *, unsigned> BlockEntryLiveMap;
+
+      if (BB == EntryBB)
+        BlockEntryLiveMap = EntryLiveMap;
+      else {
+        for (BasicBlock *Pred : predecessors(BB)) {
+          // Predecessors not visited yet contribute nothing to the merge.
+          auto PredIt = BlockExitStates.find(Pred);
+          if (PredIt == BlockExitStates.end())
+            continue;
+
+          const DenseMap<Value *, unsigned> &PredExitMap = PredIt->second;
+          for (auto &[V, Cost] : PredExitMap) {
+            if (GlobalDeadSet.count(V))
+              continue;
+
+            BlockEntryLiveMap[V] = Cost;
+          }
+        }
+      }
+
+      // The entry map is moved in: it is not read again in this iteration.
+      BlockExitStates[BB] = processBlock(*BB, std::move(BlockEntryLiveMap));
+    }
+
+    LLVM_DEBUG(dbgs() << "  Max pressure: " << getMaxVGPRs() << " VGPRs\n");
+  }
+
+private:
+  static std::string getBBName(const BasicBlock *BB) {
+    // Unnamed blocks fall back to their operand form without the leading '%'.
+    if (StringRef Label = BB->getName(); !Label.empty())
+      return Label.str();
+    std::string Printed;
+    raw_string_ostream Stream(Printed);
+    BB->printAsOperand(Stream, false);
+    if (!Printed.empty() && Printed.front() == '%')
+      Printed.erase(0, 1);
+    return Printed;
+  }
+
+  // Simulates BB instruction by instruction starting from InitialLiveMap and
+  // returns the values live at its end; updates the peak and GlobalDeadSet.
+  DenseMap<Value *, unsigned>
+  processBlock(BasicBlock &BB, DenseMap<Value *, unsigned> InitialLiveMap) {
+    DenseMap<Value *, unsigned> CurrentLiveMap = std::move(InitialLiveMap);
+    unsigned CurrentPressure = computePressure(CurrentLiveMap);
+    if (CurrentPressure > MaxPressureHalfRegs)
+      MaxPressureHalfRegs = CurrentPressure;
+
+    for (Instruction &I : BB) {
+      if (I.isDebugOrPseudoInst())
+        continue;
+
+      if (!I.getType()->isVoidTy() && !I.use_empty()) {
+        if (isa<InsertElementInst>(&I) || isa<InsertValueInst>(&I)) {
+          Value *Aggregate = nullptr;
+          Value *InsertedVal = nullptr;
+
+          if (auto *IEI = dyn_cast<InsertElementInst>(&I)) {
+            Aggregate = IEI->getOperand(0);
+            InsertedVal = IEI->getOperand(1);
+          } else if (auto *IVI = dyn_cast<InsertValueInst>(&I)) {
+            Aggregate = IVI->getAggregateOperand();
+            InsertedVal = IVI->getInsertedValueOperand();
+          }
+          // An insert is charged the live cost of both of its inputs.
+          unsigned AggCost = CurrentLiveMap.lookup(Aggregate);
+          unsigned InsertedCost = CurrentLiveMap.lookup(InsertedVal);
+          unsigned NewCost = AggCost + InsertedCost;
+
+          if (NewCost > 0) {
+            CurrentLiveMap[&I] = NewCost;
+            CurrentPressure += NewCost;
+          }
+        } else if (isa<ExtractValueInst>(&I) || isa<ExtractElementInst>(&I)) {
+          Value *Source = nullptr;
+          if (auto *EVI = dyn_cast<ExtractValueInst>(&I))
+            Source = EVI->getAggregateOperand();
+          else if (auto *EEI = dyn_cast<ExtractElementInst>(&I))
+            Source = EEI->getVectorOperand();
+
+          bool IsSourceVGPR = Source && CurrentLiveMap.count(Source);
+
+          unsigned ExtractCost = getVgprCost(&I, DL, UA);
+
+          if (ExtractCost > 0 && IsSourceVGPR) {
+            CurrentLiveMap[&I] = ExtractCost;
+            CurrentPressure += ExtractCost;
+            // Move the extracted share of the cost off the source value.
+            auto SourceIt = CurrentLiveMap.find(Source);
+            if (SourceIt != CurrentLiveMap.end()) {
+              unsigned OldCost = SourceIt->second;
+              if (OldCost >= ExtractCost) {
+                SourceIt->second -= ExtractCost;
+                CurrentPressure -= ExtractCost;
+
+                if (SourceIt->second == 0)
+                  CurrentLiveMap.erase(SourceIt);
+              }
+            }
+          }
+        } else {
+          bool HasLiveVgprOperand = false;
+          for (Use &Op : I.operands()) {
+            Value *V = Op.get();
+            if (isa<Constant>(V) || isa<BasicBlock>(V))
+              continue;
+            if (CurrentLiveMap.count(V)) {
+              HasLiveVgprOperand = true;
+              break;
+            }
+          }
+          // Only results fed by at least one live tracked operand are counted.
+          if (HasLiveVgprOperand) {
+            unsigned Cost = getVgprCost(&I, DL, UA);
+
+            if (Cost > 0) {
+              CurrentLiveMap[&I] = Cost;
+              CurrentPressure += Cost;
+            }
+          }
+        }
+      }
+      // Retire operands that isValueDead() reports as dead after I.
+      for (Use &Op : I.operands()) {
+        Value *V = Op.get();
+        if (isa<Constant>(V) || isa<BasicBlock>(V))
+          continue;
+
+        auto It = CurrentLiveMap.find(V);
+        if (It != CurrentLiveMap.end()) {
+          bool IsDead = isValueDead(V, &I, ReachCache);
+
+          if (IsDead) {
+            unsigned Cost = It->second;
+
+            CurrentLiveMap.erase(It);
+            CurrentPressure -= Cost;
+
+            GlobalDeadSet.insert(V);
+          }
+        }
+      }
+
+      if (CurrentPressure > MaxPressureHalfRegs)
+        MaxPressureHalfRegs = CurrentPressure;
+    }
+
+    return CurrentLiveMap;
+  }
+
+  unsigned computePressure(const DenseMap<Value *, unsigned> &LiveMap) {
+    unsigned Sum = 0; // total cost of all live values, in half-registers
+    for (const auto &Entry : LiveMap)
+      Sum += Entry.second;
+    return Sum;
+  }
+};
+
+} // end anonymous namespace
+
+namespace llvm {
+
+unsigned computeMaxVGPRPressure(Function &F, DominatorTree &DT,
+                                PostDominatorTree *PDT,
+                                const UniformityInfo &UA) {
+  ::AMDGPURegPressureEstimator Estimator(F, DT, PDT, UA);
+  Estimator.analyze(); // walks F and records the peak pressure
+  return Estimator.getMaxVGPRs(); // peak in whole VGPRs, rounded up
+}
+
+AnalysisKey AMDGPURegPressureEstimatorAnalysis::Key;
+
+AMDGPURegPressureEstimatorAnalysis::Result
+AMDGPURegPressureEstimatorAnalysis::run(Function &F,
+                                        FunctionAnalysisManager &AM) {
+  auto &DomTree = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &PostDomTree = AM.getResult<PostDominatorTreeAnalysis>(F);
+  auto &Uniformity = AM.getResult<UniformityInfoAnalysis>(F);
+  // The result only carries the estimated VGPR count.
+  return AMDGPURegPressureEstimatorResult(
+      computeMaxVGPRPressure(F, DomTree, &PostDomTree, Uniformity));
+}
+
+PreservedAnalyses
+AMDGPURegPressureEstimatorPrinterPass::run(Function &F,
+                                           FunctionAnalysisManager &AM) {
+  auto Result = AM.getResult<AMDGPURegPressureEstimatorAnalysis>(F);
+  OS << "AMDGPU Register Pressure for function '" << F.getName()
+     << "': " << Result.MaxVGPRs << " VGPRs (IR-level estimate)\n";
+  return PreservedAnalyses::all();
+}
+
+char AMDGPURegPressureEstimatorWrapperPass::ID = 0;
+
+AMDGPURegPressureEstimatorWrapperPass::AMDGPURegPressureEstimatorWrapperPass()
+    : FunctionPass(ID) {
+  initializeAMDGPURegPressureEstimatorWrapperPassPass(
+      *PassRegistry::getPassRegistry());
+}
+
+bool AMDGPURegPressureEstimatorWrapperPass::runOnFunction(Function &F) {
+  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto *PDTPass = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
+  PostDominatorTree *PDT = PDTPass ? &PDTPass->getPostDomTree() : nullptr;
+  auto &UA = getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
+  // PDT is null when no post-dominator tree pass is available.
+  MaxVGPRs = computeMaxVGPRPressure(F, DT, PDT, UA);
+  return false;
+}
+
+void AMDGPURegPressureEstimatorWrapperPass::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<UniformityInfoWrapperPass>();
+  AU.addUsedIfAvailable<PostDominatorTreeWrapperPass>();
+}
+
+void AMDGPURegPressureEstimatorWrapperPass::print(raw_ostream &OS,
+                                                  const Module *) const {
+  OS << "AMDGPU Register Pressure: " << MaxVGPRs
+     << " VGPRs (IR-level estimate)\n";
+}
+
+INITIALIZE_PASS_BEGIN(AMDGPURegPressureEstimatorWrapperPass,
+                      "amdgpu-reg-pressure-estimator",
+                      "AMDGPU Register Pressure Estimator", false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
+INITIALIZE_PASS_END(AMDGPURegPressureEstimatorWrapperPass,
+                    "amdgpu-reg-pressure-estimator",
+                    "AMDGPU Register Pressure Estimator", false, true)
+
+} // end namespace llvm
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.h b/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.h
new file mode 100644
index 0000000000000..da21db4c9085e
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.h
@@ -0,0 +1,92 @@
+//===-- AMDGPURegPressureEstimator.h - AMDGPU Reg Pressure -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Estimates VGPR register pressure at IR level for AMDGPU.
+///
+/// Note: This is a conservative estimate intended for comparing register
+/// pressure before and after optimization passes, not for precise register
+/// allocation decisions. The estimator may overestimate pressure, especially
+/// when there are duplicated extractelement and shufflevector operations,
+/// as it does not fully account for optimizations like CSE.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUREGPRESSUREESTIMATOR_H
+#define LLVM_LIB_TARGET_AMDGPU_AMDGPUREGPRESSUREESTIMATOR_H
+
+#include "llvm/ADT/GenericUniformityInfo.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/SSAContext.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+
+class Function;
+class DominatorTree;
+class PostDominatorTree;
+
+unsigned computeMaxVGPRPressure(Function &F, DominatorTree &DT,
+                                PostDominatorTree *PDT,
+                                const GenericUniformityInfo<SSAContext> &UA);
+
+struct AMDGPURegPressureEstimatorResult {
+  unsigned MaxVGPRs;
+  // A default-constructed result reports zero VGPRs.
+  AMDGPURegPressureEstimatorResult() : MaxVGPRs(0) {}
+  explicit AMDGPURegPressureEstimatorResult(unsigned VGPRs) : MaxVGPRs(VGPRs) {}
+  // Reports the result as invalid unless both analysis sets are preserved.
+  bool invalidate(Function &, const PreservedAnalyses &PA,
+                  FunctionAnalysisManager::Invalidator &) {
+    return !(PA.allAnalysesInSetPreserved<CFGAnalyses>() &&
+             PA.allAnalysesInSetPreserved<AllAnalysesOn<Function>>());
+  }
+};
+
+class AMDGPURegPressureEstimatorAnalysis
+    : public AnalysisInfoMixin<AMDGPURegPressureEstimatorAnalysis> {
+  friend AnalysisInfoMixin<AMDGPURegPressureEstimatorAnalysis>;
+  static AnalysisKey Key;
+
+public:
+  using Result = AMDGPURegPressureEstimatorResult;
+
+  Result run(Function &F, FunctionAnalysisManager &AM);
+};
+
+class AMDGPURegPressureEstimatorPrinterPass
+    : public PassInfoMixin<AMDGPURegPressureEstimatorPrinterPass> {
+  raw_ostream &OS;
+
+public:
+  explicit AMDGPURegPressureEstimatorPrinterPass(raw_ostream &OS) : OS(OS) {}
+
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+
+  static bool isRequired() { return true; }
+};
+
+class AMDGPURegPressureEstimatorWrapperPass : public FunctionPass {
+  unsigned MaxVGPRs = 0;
+
+public:
+  static char ID;
+
+  AMDGPURegPressureEstimatorWrapperPass();
+
+  unsigned getMaxVGPRs() const { return MaxVGPRs; }
+
+  bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  void print(raw_ostream &OS, const Module *M = nullptr) const override;
+};
+
+} // end namespace llvm
+
+#endif // LLVM_LIB_TARGET_AMDGPU_AMDGPUREGPRESSUREESTIMATOR_H
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.cpp
new file mode 100644
index 0000000000000..d8bf242946171
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.cpp
@@ -0,0 +1,301 @@
+//===- AMDGPURegPressureGuard.cpp - Register Pressure Guarded Pass Wrapper ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements a guard mechanism for IR transformations that measures
+/// VGPR register pressure before and after applying a pass, reverting the
+/// transformation if pressure increases beyond a configurable threshold.
+///
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "AMDGPURegPressureGuard.h"
+#include "AMDGPURegPressureEstimator.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include <memory>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-reg-pressure-guard"
+
+STATISTIC(NumFunctionsGuarded, "Number of functions checked by guard");
+STATISTIC(NumTransformationsReverted, "Number of transformations reverted");
+STATISTIC(NumTransformationsKept, "Number of transformations kept");
+
+namespace llvm {
+namespace AMDGPURegPressureGuardHelper {
+
+bool shouldGuardFunction(const AMDGPURegPressureGuardConfig &Config,
+                         Function &F, unsigned BaselineVGPRs) {
+  // Only functions whose baseline reaches the configured minimum are guarded.
+  if (BaselineVGPRs >= Config.MinBaselineVGPRs)
+    return true;
+  LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Skipping " << F.getName()
+                    << " - baseline VGPRs (" << BaselineVGPRs
+                    << ") below threshold (" << Config.MinBaselineVGPRs
+                    << ")\n");
+  return false;
+}
+
+bool shouldRevert(const AMDGPURegPressureGuardConfig &Config,
+                  unsigned BaselineVGPRs, unsigned NewVGPRs) {
+  if (NewVGPRs <= BaselineVGPRs) {
+    LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Keeping transformation - "
+                      << (NewVGPRs < BaselineVGPRs ? "VGPRs decreased from "
+                                                   : "VGPRs unchanged from ")
+                      << BaselineVGPRs << " to " << NewVGPRs << "\n");
+    return false;
+  }
+  // A zero baseline has no meaningful relative increase; never revert.
+  if (BaselineVGPRs == 0)
+    return false;
+  // Relative increase in whole percent, rounded down.
+  unsigned PercentIncrease = ((NewVGPRs - BaselineVGPRs) * 100) / BaselineVGPRs;
+  if (PercentIncrease > Config.MaxPercentIncrease) {
+    LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Reverting transformation - "
+                      << "VGPR increase " << PercentIncrease
+                      << "% exceeds limit " << Config.MaxPercentIncrease
+                      << "%\n");
+    return true;
+  }
+
+  LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Keeping transformation - "
+                    << "VGPR increase " << PercentIncrease
+                    << "% within limit\n");
+  return false;
+}
+
+void restoreFunction(Function &F, Function &BackupFunc) {
+  F.dropAllReferences();
+
+  // Map every backup argument onto the argument of F at the same position.
+  ValueToValueMapTy ArgMap;
+  Function::arg_iterator Dest = F.arg_begin();
+  for (Argument &BackupArg : BackupFunc.args())
+    ArgMap[&BackupArg] = &*Dest++;
+
+  SmallVector<ReturnInst *, 8> ClonedReturns;
+  CloneFunctionInto(&F, &BackupFunc, ArgMap,
+                    CloneFunctionChangeType::LocalChangesOnly, ClonedReturns);
+}
+
+} // namespace AMDGPURegPressureGuardHelper
+} // namespace llvm
+
+namespace {
+
+struct RegPressureGuardState {
+  unsigned BaselineVGPRs = 0;          // estimate taken by the baseline pass
+  Function *BackupFunc = nullptr;      // clone of the function at that point
+  AMDGPURegPressureGuardConfig Config; // thresholds used when verifying
+  bool ShouldGuard = false;            // set when the function is guarded
+};
+
+static DenseMap<Function *, std::unique_ptr<RegPressureGuardState>>
+    GuardStateMap;
+
+class RegPressureBaselineMeasurementPass : public FunctionPass {
+  AMDGPURegPressureGuardConfig Config;
+
+public:
+  static char ID;
+
+  explicit RegPressureBaselineMeasurementPass(
+      const AMDGPURegPressureGuardConfig &Cfg = {})
+      : FunctionPass(ID), Config(Cfg) {
+    initializeRegPressureBaselineMeasurementPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+    auto &UA = getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
+    // Estimate the current VGPR pressure of F to serve as the baseline.
+    unsigned BaselineVGPRs = llvm::computeMaxVGPRPressure(F, DT, &PDT, UA);
+
+    bool ShouldGuard = llvm::AMDGPURegPressureGuardHelper::shouldGuardFunction(
+        Config, F, BaselineVGPRs);
+
+    if (!ShouldGuard) {
+      LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Skipping " << F.getName()
+                        << "\n");
+      return false;
+    }
+
+    ++NumFunctionsGuarded;
+
+    LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Measuring baseline for "
+                      << F.getName() << " (baseline: " << BaselineVGPRs
+                      << " VGPRs)\n");
+    // Record the baseline and thresholds for the verification pass.
+    auto State = std::make_unique<RegPressureGuardState>();
+    State->BaselineVGPRs = BaselineVGPRs;
+    State->Config = Config;
+    State->ShouldGuard = true;
+    // Keep a clone of F so the verification pass can restore the body.
+    ValueToValueMapTy VMap;
+    State->BackupFunc = CloneFunction(&F, VMap);
+
+    GuardStateMap[&F] = std::move(State);
+    // NOTE(review): a clone was created above, yet no change is reported.
+    return false;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<PostDominatorTreeWrapperPass>();
+    AU.addRequired<UniformityInfoWrapperPass>();
+    AU.setPreservesAll();
+  }
+
+  StringRef getPassName() const override {
+    return "AMDGPU Register Pressure Baseline Measurement";
+  }
+};
+
+class RegPressureVerificationPass : public FunctionPass {
+public:
+  static char ID;
+
+  RegPressureVerificationPass() : FunctionPass(ID) {
+    initializeRegPressureVerificationPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto It = GuardStateMap.find(&F);
+    if (It == GuardStateMap.end())
+      return false;
+    // State without ShouldGuard is discarded without any verification.
+    auto &State = *It->second;
+    if (!State.ShouldGuard) {
+      GuardStateMap.erase(It);
+      return false;
+    }
+
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+    auto &UA = getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
+    // Re-estimate the pressure of F as it is now.
+    unsigned NewVGPRs = llvm::computeMaxVGPRPressure(F, DT, &PDT, UA);
+
+    LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Verifying " << F.getName()
+                      << " (baseline: " << State.BaselineVGPRs
+                      << ", new: " << NewVGPRs << " VGPRs)\n");
+
+    bool ShouldRevert = llvm::AMDGPURegPressureGuardHelper::shouldRevert(
+        State.Config, State.BaselineVGPRs, NewVGPRs);
+
+    if (ShouldRevert) {
+      LLVM_DEBUG(dbgs() << "AMDGPURegPressureGuard: Reverting " << F.getName()
+                        << "\n");
+      llvm::AMDGPURegPressureGuardHelper::restoreFunction(F, *State.BackupFunc);
+      ++NumTransformationsReverted;
+    } else {
+      ++NumTransformationsKept;
+    }
+    // The backup clone is erased whether or not F was restored.
+    State.BackupFunc->eraseFromParent();
+    GuardStateMap.erase(It);
+
+    return ShouldRevert;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<PostDominatorTreeWrapperPass>();
+    AU.addRequired<UniformityInfoWrapperPass>();
+  }
+
+  StringRef getPassName() const override {
+    return "AMDGPU Register Pressure Verification";
+  }
+};
+
+} // end anonymous namespace
+
+char RegPressureBaselineMeasurementPass::ID = 0;
+char RegPressureVerificationPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(RegPressureBaselineMeasurementPass,
+                      "amdgpu-reg-pressure-baseline",
+                      "AMDGPU Register Pressure Baseline Measurement", false,
+                      false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(RegPressureBaselineMeasurementPass,
+                    "amdgpu-reg-pressure-baseline",
+                    "AMDGPU Register Pressure Baseline Measurement", false,
+                    false)
+
+INITIALIZE_PASS_BEGIN(RegPressureVerificationPass,
+                      "amdgpu-reg-pressure-verification",
+                      "AMDGPU Register Pressure Verification", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(RegPressureVerificationPass,
+                    "amdgpu-reg-pressure-verification",
+                    "AMDGPU Register Pressure Verification", false, false)
+
+namespace llvm {
+
+FunctionPass *createRegPressureBaselineMeasurementPass(
+    const AMDGPURegPressureGuardConfig &Config) {
+  return new RegPressureBaselineMeasurementPass(Config);
+}
+
+FunctionPass *createRegPressureVerificationPass() {
+  return new RegPressureVerificationPass();
+}
+
+} // namespace llvm
+
+// Deprecated placeholder pass; running it is a programming error.
+class AMDGPURegPressureGuardLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  AMDGPURegPressureGuardLegacyPass() : FunctionPass(ID) {
+    initializeAMDGPURegPressureGuardLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &) override {
+    llvm_unreachable("Use createRegPressureBaselineMeasurementPass + "
+                     "transformation pass + createRegPressureVerificationPass");
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+  }
+};
+
+char AMDGPURegPressureGuardLegacyPass::ID = 0;
+
+INITIALIZE_PASS(AMDGPURegPressureGuardLegacyPass,
+                "amdgpu-reg-pressure-guard-legacy",
+                "AMDGPU Register Pressure Guard (Legacy - Deprecated)", false,
+                false)
+
+namespace llvm {
+
+// Never returns a pass: reaching this call is a programming error.
+FunctionPass *createAMDGPURegPressureGuardLegacyPass(
+    FunctionPass *WrappedPass, const AMDGPURegPressureGuardConfig &Config) {
+  llvm_unreachable("Use createRegPressureBaselineMeasurementPass + "
+                   "transformation pass + createRegPressureVerificationPass");
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.h b/llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.h
new file mode 100644
index 0000000000000..6bad3cb306a00
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegPressureGuard.h
@@ -0,0 +1,124 @@
+//===- AMDGPURegPressureGuard.h - Reg Pressure Guard -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Guards transformations by measuring VGPR register pressure using the
+/// AMDGPURegPressureEstimator before and after applying a pass. If the
+/// pressure increases beyond a configurable threshold, the transformation
+/// is reverted to prevent potential register spilling.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUREGPRESSUREGUARD_H
+#define LLVM_LIB_TARGET_AMDGPU_AMDGPUREGPRESSUREGUARD_H
+
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+class DominatorTree;
+class PostDominatorTree;
+class UniformityInfoAnalysis;
+
+class Function;
+
+/// Tuning knobs shared by the legacy and new-pass-manager guard passes.
+struct AMDGPURegPressureGuardConfig {
+  // NOTE(review): presumably the largest tolerated growth of the VGPR
+  // estimate, in percent of the baseline, before a transformation is
+  // reverted -- confirm against shouldRevert().
+  unsigned MaxPercentIncrease = 20;
+  // NOTE(review): presumably the baseline VGPR estimate below which the
+  // guard is not applied -- confirm against shouldGuardFunction().
+  unsigned MinBaselineVGPRs = 96;
+};
+
+/// Creates the legacy pass scheduled before a guarded transformation.
+FunctionPass *createRegPressureBaselineMeasurementPass(
+    const AMDGPURegPressureGuardConfig &Config);
+
+/// Creates the legacy pass scheduled after a guarded transformation.
+FunctionPass *createRegPressureVerificationPass();
+
+/// Deprecated entry point: its definition does not construct a pass and only
+/// directs callers to the baseline-measurement / verification pair above.
+FunctionPass *createAMDGPURegPressureGuardLegacyPass(
+    FunctionPass *WrappedPass, const AMDGPURegPressureGuardConfig &Config =
+                                   AMDGPURegPressureGuardConfig());
+
+/// New-pass-manager wrapper that owns a \p PassT instance and a copy of the
+/// guard configuration. run() is defined further down in this header.
+template <typename PassT>
+class AMDGPURegPressureGuardPass
+    : public PassInfoMixin<AMDGPURegPressureGuardPass<PassT>> {
+  PassT WrappedPass;
+  AMDGPURegPressureGuardConfig Config;
+
+public:
+  /// Takes ownership of \p Pass by moving it into the wrapper; \p Cfg is
+  /// copied and defaults to a value-initialized configuration.
+  explicit AMDGPURegPressureGuardPass(
+      PassT Pass, const AMDGPURegPressureGuardConfig &Cfg = {})
+      : WrappedPass(std::move(Pass)), Config(Cfg) {}
+
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+
+  /// Marks the wrapper as a required pass.
+  static bool isRequired() { return true; }
+};
+
+} // namespace llvm
+
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+namespace llvm {
+
+/// Estimates the maximum VGPR pressure of \p F. The run() template below
+/// passes the cached post-dominator tree as \p PDT, which may be null.
+// NOTE(review): PATCH 2/2 of this series turns the definition of this
+// function into a file-local static inside an anonymous namespace. This
+// declaration, and the calls in run() below, would then have no definition to
+// link against -- confirm before merging.
+unsigned computeMaxVGPRPressure(Function &F, DominatorTree &DT,
+                                PostDominatorTree *PDT,
+                                const UniformityInfo &UA);
+
+/// Non-template helpers used by AMDGPURegPressureGuardPass::run(). They are
+/// only declared here; their definitions are not part of this header.
+namespace AMDGPURegPressureGuardHelper {
+// Decides whether \p F is guarded at all, given its baseline estimate.
+bool shouldGuardFunction(const AMDGPURegPressureGuardConfig &Config,
+                         Function &F, unsigned BaselineVGPRs);
+// Decides whether the change from \p BaselineVGPRs to \p NewVGPRs must be
+// undone.
+bool shouldRevert(const AMDGPURegPressureGuardConfig &Config,
+                  unsigned BaselineVGPRs, unsigned NewVGPRs);
+// Restores \p F from \p BackupFunc; run() passes the clone it took before
+// the wrapped pass ran.
+void restoreFunction(Function &F, Function &BackupFunc);
+} // namespace AMDGPURegPressureGuardHelper
+
+/// Runs the wrapped pass on \p F and undoes it if the estimated VGPR pressure
+/// grows too much.
+///
+/// The pressure is estimated with computeMaxVGPRPressure() before and after
+/// the wrapped pass. If shouldGuardFunction() declines \p F, the wrapped pass
+/// runs unguarded. Otherwise \p F is cloned first, and when shouldRevert()
+/// flags the new estimate the body of \p F is restored from that clone.
+///
+/// \returns PreservedAnalyses::none() after a revert, otherwise whatever the
+/// wrapped pass reported.
+template <typename PassT>
+PreservedAnalyses
+AMDGPURegPressureGuardPass<PassT>::run(Function &F,
+                                       FunctionAnalysisManager &AM) {
+  using namespace AMDGPURegPressureGuardHelper;
+
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  // Only a cached post-dominator tree is used; it may be null.
+  auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
+  auto &UA = AM.getResult<UniformityInfoAnalysis>(F);
+
+  unsigned BaselineVGPRs = computeMaxVGPRPressure(F, DT, PDT, UA);
+
+  if (!shouldGuardFunction(Config, F, BaselineVGPRs))
+    return WrappedPass.run(F, AM);
+
+  // NOTE(review): the clone is erased from its parent below, so it is
+  // inserted into F's module while this function pass runs. Confirm that
+  // mutating the module's function list from a function pass is acceptable.
+  ValueToValueMapTy VMap;
+  Function *BackupFunc = CloneFunction(&F, VMap);
+
+  PreservedAnalyses PA = WrappedPass.run(F, AM);
+
+  // The pass manager invalidates analyses only after this wrapper returns.
+  // Drop whatever the wrapped pass did not preserve before querying again;
+  // otherwise the second estimate could be computed from results that still
+  // describe the pre-transformation IR.
+  AM.invalidate(F, PA);
+
+  auto &NewDT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto *NewPDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
+  auto &NewUA = AM.getResult<UniformityInfoAnalysis>(F);
+
+  unsigned NewVGPRs = computeMaxVGPRPressure(F, NewDT, NewPDT, NewUA);
+
+  const bool Revert = shouldRevert(Config, BaselineVGPRs, NewVGPRs);
+  if (Revert) {
+    // Release every cached analysis before the IR it refers to is replaced.
+    // Nothing is cached between here and the return, and returning none()
+    // makes the caller invalidate again, so one call is sufficient.
+    AM.invalidate(F, PreservedAnalyses::none());
+    restoreFunction(F, *BackupFunc);
+  }
+
+  // The clone is no longer needed on either path.
+  BackupFunc->eraseFromParent();
+  return Revert ? PreservedAnalyses::none() : PA;
+}
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_AMDGPU_AMDGPUREGPRESSUREGUARD_H
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index e5a35abe6da6b..18fb0a584e8ab 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -595,6 +595,10 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAMDGPUTarget() {
   initializeAMDGPURewriteAGPRCopyMFMALegacyPass(*PR);
   initializeAMDGPURewriteOutArgumentsPass(*PR);
   initializeAMDGPURewriteUndefForPHILegacyPass(*PR);
+  initializeAMDGPURegPressureEstimatorWrapperPassPass(*PR);
+  initializeRegPressureBaselineMeasurementPassPass(*PR);
+  initializeRegPressureVerificationPassPass(*PR);
+  initializeAMDGPURegPressureGuardLegacyPassPass(*PR);
   initializeSIAnnotateControlFlowLegacyPass(*PR);
   initializeAMDGPUInsertDelayAluLegacyPass(*PR);
   initializeAMDGPULowerVGPREncodingLegacyPass(*PR);
@@ -1276,6 +1280,13 @@ AMDGPUPassConfig::AMDGPUPassConfig(TargetMachine &TM, PassManagerBase &PM)
   disablePass(&ShadowStackGCLoweringID);
 }
 
+/// Schedules \p P bracketed by the register-pressure guard passes: the
+/// baseline measurement pass built from \p Config, then \p P itself, then
+/// the verification pass.
+void AMDGPUPassConfig::addAMDGPURegPressureGuardedPass(
+    Pass *P, const AMDGPURegPressureGuardConfig &Config) {
+  addPass(createRegPressureBaselineMeasurementPass(Config));
+  addPass(P);
+  addPass(createRegPressureVerificationPass());
+}
+
 void AMDGPUPassConfig::addEarlyCSEOrGVNPass() {
   if (getOptLevel() == CodeGenOptLevel::Aggressive)
     addPass(createGVNPass());
@@ -1382,7 +1393,7 @@ void AMDGPUPassConfig::addIRPasses() {
     // Try to hoist loop invariant parts of divisions AMDGPUCodeGenPrepare may
     // have expanded.
     if (TM.getOptLevel() > CodeGenOptLevel::Less)
-      addPass(createLICMPass());
+      addAMDGPURegPressureGuardedPass(createLICMPass());
   }
 
   TargetPassConfig::addIRPasses();
@@ -1459,7 +1470,7 @@ bool GCNPassConfig::addPreISel() {
   AMDGPUPassConfig::addPreISel();
 
   if (TM->getOptLevel() > CodeGenOptLevel::None)
-    addPass(createSinkingPass());
+    addAMDGPURegPressureGuardedPass(createSinkingPass());
 
   if (TM->getOptLevel() > CodeGenOptLevel::None)
     addPass(createAMDGPULateCodeGenPrepareLegacyPass());
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
index 06a3047196b8a..4f07a0da39686 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUTARGETMACHINE_H
 #define LLVM_LIB_TARGET_AMDGPU_AMDGPUTARGETMACHINE_H
 
+#include "AMDGPURegPressureGuard.h"
 #include "GCNSubtarget.h"
 #include "llvm/CodeGen/CodeGenTargetMachineImpl.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
@@ -141,6 +142,9 @@ class AMDGPUPassConfig : public TargetPassConfig {
   bool addInstSelector() override;
   bool addGCPasses() override;
 
+  void addAMDGPURegPressureGuardedPass(
+      Pass *P, const AMDGPURegPressureGuardConfig &Config = {});
+
   std::unique_ptr<CSEConfigBase> getCSEConfig() const override;
 
   /// Check if a pass is enabled given \p Opt option. The option always
diff --git a/llvm/lib/Target/AMDGPU/CMakeLists.txt b/llvm/lib/Target/AMDGPU/CMakeLists.txt
index 782cbfa76e6e9..f13c7cfaaeec2 100644
--- a/llvm/lib/Target/AMDGPU/CMakeLists.txt
+++ b/llvm/lib/Target/AMDGPU/CMakeLists.txt
@@ -104,6 +104,8 @@ add_llvm_target(AMDGPUCodeGen
   AMDGPURegBankLegalizeRules.cpp
   AMDGPURegBankSelect.cpp
   AMDGPURegisterBankInfo.cpp
+  AMDGPURegPressureEstimator.cpp
+  AMDGPURegPressureGuard.cpp
   AMDGPURemoveIncompatibleFunctions.cpp
   AMDGPUReserveWWMRegs.cpp
   AMDGPUResourceUsageAnalysis.cpp

>From e077ef3c8f168cf4d836bc8b0c9189215577413b Mon Sep 17 00:00:00 2001
From: tianhbai <tianhbai at gmail.com>
Date: Tue, 9 Dec 2025 18:23:48 +0800
Subject: [PATCH 2/2] Address Carl's comments

---
 .../AMDGPU/AMDGPURegPressureEstimator.cpp     | 92 ++++++++++++++-----
 1 file changed, 69 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp
index b9114673cfc53..b10bc4e29a173 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegPressureEstimator.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/Support/AMDGPUAddrSpace.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -36,7 +37,7 @@ namespace {
 // Returns VGPR cost in half-registers (16-bit units).
 // Returns 0 for SGPRs, constants, uniform values, and i1 types.
 static unsigned getVgprCost(Value *V, const DataLayout &DL,
-                            const UniformityInfo &UA) {
+                            const UniformityInfo &UA, bool UseRealTrue16) {
   if (!V)
     return 0;
 
@@ -51,11 +52,11 @@ static unsigned getVgprCost(Value *V, const DataLayout &DL,
   if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
     unsigned AS = PtrTy->getAddressSpace();
     switch (AS) {
-    case 7:
+    case AMDGPUAS::BUFFER_FAT_POINTER:
       return 2; // offset
-    case 8:
+    case AMDGPUAS::BUFFER_RESOURCE:
       return 0;
-    case 9:
+    case AMDGPUAS::BUFFER_STRIDED_POINTER:
       return 4; // offset + index
     default:
       unsigned BitWidth = DL.getPointerSizeInBits(AS);
@@ -67,8 +68,9 @@ static unsigned getVgprCost(Value *V, const DataLayout &DL,
   if (Ty->isIntegerTy())
     return ((BitWidth + 31) / 32) * 2;
 
-  // Assuming RealTrue16 on
-  return (BitWidth + 15) / 16;
+  if (UseRealTrue16)
+    return (BitWidth + 15) / 16;
+  return ((BitWidth + 31) / 32) * 2;
 }
 
 // Caches block-to-block reachability queries to avoid redundant BFS traversals.
@@ -86,7 +88,7 @@ class ReachabilityCache {
   ReachabilityCache(DominatorTree &DT, PostDominatorTree *PDT)
       : PDT(PDT), DT(DT) {}
 
-  template <typename RPOTType> void initRPO(RPOTType &RPOT) {
+  void initRPO(ReversePostOrderTraversal<Function *> &RPOT) {
     unsigned Idx = 0;
     for (auto *BB : RPOT)
       RPOIndex[BB] = Idx++;
@@ -173,19 +175,23 @@ class ReachabilityCache {
   }
 };
 
-static bool isValueDead(Value *V, Instruction *I, ReachabilityCache &Cache) {
+// Checks if a value becomes dead after a specific instruction.
+// Returns true if V has no uses reachable from AfterInst, meaning V's
+// live range ends at AfterInst and can be removed from the pressure tracking.
+static bool isValueDead(Value *V, Instruction *AfterInst,
+                        ReachabilityCache &Cache) {
   for (User *U : V->users()) {
     Instruction *UseInst = dyn_cast<Instruction>(U);
     if (!UseInst)
       continue;
 
-    if (UseInst == I)
+    if (UseInst == AfterInst)
       continue;
 
-    if (Cache.DT.dominates(UseInst, I))
+    if (Cache.DT.dominates(UseInst, AfterInst))
       continue;
 
-    if (Cache.isReachable(I, UseInst))
+    if (Cache.isReachable(AfterInst, UseInst))
       return false;
   }
 
@@ -205,18 +211,39 @@ class AMDGPURegPressureEstimator {
   DenseSet<Value *> GlobalDeadSet;
   ReachabilityCache ReachCache;
   unsigned MaxPressureHalfRegs = 0;
+  bool UseRealTrue16;
 
 public:
   AMDGPURegPressureEstimator(Function &F, DominatorTree &DT,
                              PostDominatorTree *PDT, const UniformityInfo &UA)
       : F(F), DT(DT), PDT(PDT), UA(UA), DL(F.getParent()->getDataLayout()),
-        ReachCache(DT, PDT) {}
+        ReachCache(DT, PDT) {
+    // Check if real-true16 feature is enabled for this function.
+    Attribute FSAttr = F.getFnAttribute("target-features");
+    UseRealTrue16 = FSAttr.isValid() &&
+                    FSAttr.getValueAsString().contains("+real-true16");
+  }
 
   unsigned getMaxVGPRs() const { return (MaxPressureHalfRegs + 1) / 2; }
 
   void analyze() {
     LLVM_DEBUG(dbgs() << "Analyzing function: " << F.getName() << "\n");
 
+    // Main algorithm: Forward dataflow analysis in reverse post-order (RPO)
+    //
+    // RPO traversal ensures that:
+    // 1. Each block is visited after all its predecessors (except back edges)
+    // 2. Loop headers are visited before loop bodies
+    // 3. This allows accurate propagation of live value information
+    //
+    // For each basic block, we:
+    // 1. Merge live-in states from all predecessors
+    // 2. Process instructions sequentially, tracking:
+    //    - New values becoming live (instruction results)
+    //    - Values becoming dead (last use detected)
+    //    - Special handling for insert/extract operations
+    // 3. Record live-out state for use by successors
+    // 4. Track maximum pressure seen at any program point
     DenseMap<BasicBlock *, DenseMap<Value *, unsigned>> BlockExitStates;
 
     ReversePostOrderTraversal<Function *> RPOT(&F);
@@ -227,7 +254,7 @@ class AMDGPURegPressureEstimator {
     for (Argument &Arg : F.args()) {
       if (Arg.use_empty())
         continue;
-      unsigned Cost = getVgprCost(&Arg, DL, UA);
+      unsigned Cost = getVgprCost(&Arg, DL, UA, UseRealTrue16);
       if (Cost > 0)
         EntryLiveMap[&Arg] = Cost;
     }
@@ -290,6 +317,24 @@ class AMDGPURegPressureEstimator {
       if (I.isDebugOrPseudoInst())
         continue;
 
+      // Process instruction result: determine if it creates a new live VGPR value
+      //
+      // Three cases with different pressure impacts:
+      //
+      // Case 1: Insert operations (insertelement, insertvalue)
+      //   - Conservative approach: assume both aggregate and inserted value are live
+      //   - Pressure increase = aggregate_cost + inserted_value_cost
+      //   - Rationale: Without dataflow, we can't track which vector lanes are live
+      //
+      // Case 2: Extract operations (extractelement, extractvalue)
+      //   - Extract creates a new live value (extract_cost)
+      //   - Source aggregate cost can be reduced by extract_cost
+      //   - Handles partial liveness of vectors/aggregates
+      //   - Only applies if source is a tracked VGPR value
+      //
+      // Case 3: Other operations
+      //   - Result becomes live only if it has at least one live VGPR operand
+      //   - This prevents counting operations on uniform/constant values
       if (!I.getType()->isVoidTy() && !I.use_empty()) {
         if (isa<InsertElementInst>(&I) || isa<InsertValueInst>(&I)) {
           Value *Aggregate = nullptr;
@@ -320,7 +365,7 @@ class AMDGPURegPressureEstimator {
 
           bool IsSourceVGPR = Source && CurrentLiveMap.count(Source);
 
-          unsigned ExtractCost = getVgprCost(&I, DL, UA);
+          unsigned ExtractCost = getVgprCost(&I, DL, UA, UseRealTrue16);
 
           if (ExtractCost > 0 && IsSourceVGPR) {
             CurrentLiveMap[&I] = ExtractCost;
@@ -351,7 +396,7 @@ class AMDGPURegPressureEstimator {
           }
 
           if (HasLiveVgprOperand) {
-            unsigned Cost = getVgprCost(&I, DL, UA);
+            unsigned Cost = getVgprCost(&I, DL, UA, UseRealTrue16);
 
             if (Cost > 0) {
               CurrentLiveMap[&I] = Cost;
@@ -361,6 +406,7 @@ class AMDGPURegPressureEstimator {
         }
       }
 
+      // Kill dead values: check if any operand's last use is at this instruction
       for (Use &Op : I.operands()) {
         Value *V = Op.get();
         if (isa<Constant>(V) || isa<BasicBlock>(V))
@@ -396,18 +442,18 @@ class AMDGPURegPressureEstimator {
   }
 };
 
-} // end anonymous namespace
-
-namespace llvm {
-
-unsigned computeMaxVGPRPressure(Function &F, DominatorTree &DT,
-                                PostDominatorTree *PDT,
-                                const UniformityInfo &UA) {
-  ::AMDGPURegPressureEstimator Estimator(F, DT, PDT, UA);
+static unsigned computeMaxVGPRPressure(Function &F, DominatorTree &DT,
+                                       PostDominatorTree *PDT,
+                                       const UniformityInfo &UA) {
+  AMDGPURegPressureEstimator Estimator(F, DT, PDT, UA);
   Estimator.analyze();
   return Estimator.getMaxVGPRs();
 }
 
+} // end anonymous namespace
+
+namespace llvm {
+
 AnalysisKey AMDGPURegPressureEstimatorAnalysis::Key;
 
 AMDGPURegPressureEstimatorAnalysis::Result



More information about the llvm-commits mailing list