[llvm] [AMDGPU] Refine AMDGPULateCodeGenPrepare class. NFC. (PR #118792)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 5 04:05:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Jay Foad (jayfoad)

<details>
<summary>Changes</summary>

Use references instead of pointers for most of the class's state and
initialize all of it in the constructor; do the same for the
LiveRegOptimizer class.


---
Full diff: https://github.com/llvm/llvm-project/pull/118792.diff


1 File Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp (+36-41) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp
index 86ed29acb09abc..830b50307f837e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp
@@ -42,25 +42,25 @@ namespace {
 
 class AMDGPULateCodeGenPrepare
     : public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
-  Module *Mod = nullptr;
-  const DataLayout *DL = nullptr;
+  Function &F;
+  const DataLayout &DL;
   const GCNSubtarget &ST;
 
-  AssumptionCache *AC = nullptr;
-  UniformityInfo *UA = nullptr;
+  AssumptionCache *const AC;
+  UniformityInfo &UA;
 
   SmallVector<WeakTrackingVH, 8> DeadInsts;
 
 public:
-  AMDGPULateCodeGenPrepare(Module &M, const GCNSubtarget &ST,
-                           AssumptionCache *AC, UniformityInfo *UA)
-      : Mod(&M), DL(&M.getDataLayout()), ST(ST), AC(AC), UA(UA) {}
-  bool run(Function &F);
+  AMDGPULateCodeGenPrepare(Function &F, const GCNSubtarget &ST,
+                           AssumptionCache *AC, UniformityInfo &UA)
+      : F(F), DL(F.getDataLayout()), ST(ST), AC(AC), UA(UA) {}
+  bool run();
   bool visitInstruction(Instruction &) { return false; }
 
   // Check if the specified value is at least DWORD aligned.
   bool isDWORDAligned(const Value *V) const {
-    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
+    KnownBits Known = computeKnownBits(V, DL, 0, AC);
     return Known.countMinTrailingZeros() >= 2;
   }
 
@@ -72,11 +72,11 @@ using ValueToValueMap = DenseMap<const Value *, Value *>;
 
 class LiveRegOptimizer {
 private:
-  Module *Mod = nullptr;
-  const DataLayout *DL = nullptr;
-  const GCNSubtarget *ST;
+  Module &Mod;
+  const DataLayout &DL;
+  const GCNSubtarget &ST;
   /// The scalar type to convert to
-  Type *ConvertToScalar;
+  Type *const ConvertToScalar;
   /// The set of visited Instructions
   SmallPtrSet<Instruction *, 4> Visited;
   /// Map of Value -> Converted Value
@@ -110,7 +110,7 @@ class LiveRegOptimizer {
     if (!VTy)
       return false;
 
-    const auto *TLI = ST->getTargetLowering();
+    const auto *TLI = ST.getTargetLowering();
 
     Type *EltTy = VTy->getElementType();
     // If the element size is not less than the convert to scalar size, then we
@@ -125,15 +125,14 @@ class LiveRegOptimizer {
     return LK.first != TargetLoweringBase::TypeLegal;
   }
 
-  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
-    DL = &Mod->getDataLayout();
-    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
-  }
+  LiveRegOptimizer(Module &Mod, const GCNSubtarget &ST)
+      : Mod(Mod), DL(Mod.getDataLayout()), ST(ST),
+        ConvertToScalar(Type::getInt32Ty(Mod.getContext())) {}
 };
 
 } // end anonymous namespace
 
-bool AMDGPULateCodeGenPrepare::run(Function &F) {
+bool AMDGPULateCodeGenPrepare::run() {
   // "Optimize" the virtual regs that cross basic block boundaries. When
   // building the SelectionDAG, vectors of illegal types that cross basic blocks
   // will be scalarized and widened, with each scalar living in its
@@ -141,7 +140,7 @@ bool AMDGPULateCodeGenPrepare::run(Function &F) {
   // vectors to equivalent vectors of legal type (which are converted back
   // before uses in subsequent blocks), to pack the bits into fewer physical
   // registers (used in CopyToReg/CopyFromReg pairs).
-  LiveRegOptimizer LRO(Mod, &ST);
+  LiveRegOptimizer LRO(*F.getParent(), ST);
 
   bool Changed = false;
 
@@ -163,15 +162,15 @@ Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
 
   FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
 
-  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
-  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+  TypeSize OriginalSize = DL.getTypeSizeInBits(VTy);
+  TypeSize ConvertScalarSize = DL.getTypeSizeInBits(ConvertToScalar);
   unsigned ConvertEltCount =
       (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
 
   if (OriginalSize <= ConvertScalarSize)
-    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+    return IntegerType::get(Mod.getContext(), ConvertScalarSize);
 
-  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+  return VectorType::get(Type::getIntNTy(Mod.getContext(), ConvertScalarSize),
                          ConvertEltCount, false);
 }
 
@@ -180,8 +179,8 @@ Value *LiveRegOptimizer::convertToOptType(Instruction *V,
   FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
   Type *NewTy = calculateConvertType(V->getType());
 
-  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
-  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+  TypeSize OriginalSize = DL.getTypeSizeInBits(VTy);
+  TypeSize NewSize = DL.getTypeSizeInBits(NewTy);
 
   IRBuilder<> Builder(V->getParent(), InsertPt);
   // If there is a bitsize match, we can fit the old vector into a new vector of
@@ -210,8 +209,8 @@ Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                             BasicBlock *InsertBB) {
   FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
 
-  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
-  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+  TypeSize OriginalSize = DL.getTypeSizeInBits(V->getType());
+  TypeSize NewSize = DL.getTypeSizeInBits(NewVTy);
 
   IRBuilder<> Builder(InsertBB, InsertPt);
   // If there is a bitsize match, we simply convert back to the original type.
@@ -224,14 +223,14 @@ Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
   // For wide scalars, we can just truncate the value.
   if (!V->getType()->isVectorTy()) {
     Instruction *Trunc = cast<Instruction>(
-        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+        Builder.CreateTrunc(V, IntegerType::get(Mod.getContext(), NewSize)));
     return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
   }
 
   // For wider vectors, we must strip the MSBs to convert back to the original
   // type.
   VectorType *ExpandedVT = VectorType::get(
-      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      Type::getIntNTy(Mod.getContext(), NewVTy->getScalarSizeInBits()),
       (OriginalSize / NewVTy->getScalarSizeInBits()), false);
   Instruction *Converted =
       cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
@@ -410,15 +409,15 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   // Skip aggregate types.
   if (Ty->isAggregateType())
     return false;
-  unsigned TySize = DL->getTypeStoreSize(Ty);
+  unsigned TySize = DL.getTypeStoreSize(Ty);
   // Only handle sub-DWORD loads.
   if (TySize >= 4)
     return false;
   // That load must be at least naturally aligned.
-  if (LI.getAlign() < DL->getABITypeAlign(Ty))
+  if (LI.getAlign() < DL.getABITypeAlign(Ty))
     return false;
   // It should be uniform, i.e. a scalar load.
-  return UA->isUniform(&LI);
+  return UA.isUniform(&LI);
 }
 
 bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
@@ -435,7 +434,7 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
 
   int64_t Offset = 0;
   auto *Base =
-      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
+      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, DL);
   // If that base is not DWORD aligned, it's not safe to perform the following
   // transforms.
   if (!isDWORDAligned(Base))
@@ -452,7 +451,7 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
   IRBuilder<> IRB(&LI);
   IRB.SetCurrentDebugLocation(LI.getDebugLoc());
 
-  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
+  unsigned LdBits = DL.getTypeStoreSizeInBits(LI.getType());
   auto *IntNTy = Type::getIntNTy(LI.getContext(), LdBits);
 
   auto *NewPtr = IRB.CreateConstGEP1_64(
@@ -480,9 +479,7 @@ AMDGPULateCodeGenPreparePass::run(Function &F, FunctionAnalysisManager &FAM) {
   AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(F);
   UniformityInfo &UI = FAM.getResult<UniformityInfoAnalysis>(F);
 
-  AMDGPULateCodeGenPrepare Impl(*F.getParent(), ST, &AC, &UI);
-
-  bool Changed = Impl.run(F);
+  bool Changed = AMDGPULateCodeGenPrepare(F, ST, &AC, UI).run();
 
   if (!Changed)
     return PreservedAnalyses::all();
@@ -524,9 +521,7 @@ bool AMDGPULateCodeGenPrepareLegacy::runOnFunction(Function &F) {
   UniformityInfo &UI =
       getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
-  AMDGPULateCodeGenPrepare Impl(*F.getParent(), ST, &AC, &UI);
-
-  return Impl.run(F);
+  return AMDGPULateCodeGenPrepare(F, ST, &AC, UI).run();
 }
 
 INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,

``````````

</details>


https://github.com/llvm/llvm-project/pull/118792


More information about the llvm-commits mailing list