[llvm] [AMDGPU] Add IR LiveReg type-based optimization (PR #66838)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 05:01:22 PDT 2024


================
@@ -102,14 +144,248 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
+  // "Optimize" the virtual regs that cross basic block boundaries. When
+  // building the SelectionDAG, vectors of illegal types that cross basic blocks
+  // will be scalarized and widened, with each scalar living in its
+  // own register. To work around this, this optimization converts the
+  // vectors to equivalent vectors of legal type (which are converted back
+  // before uses in subsequent blocks), to pack the bits into fewer physical
+  // registers (used in CopyToReg/CopyFromReg pairs).
+  LiveRegOptimizer LRO(Mod);
+
   bool Changed = false;
+
   for (auto &BB : F)
-    for (Instruction &I : llvm::make_early_inc_range(BB))
+    for (Instruction &I : make_early_inc_range(BB)) {
       Changed |= visit(I);
+      Changed |= LRO.optimizeLiveType(&I);
+    }
 
+  LRO.removeDeadInstrs();
   return Changed;
 }
 
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+  assert(OriginalType->getScalarSizeInBits() <=
+         ConvertToScalar->getScalarSizeInBits());
+
+  FixedVectorType *VTy = dyn_cast<FixedVectorType>(OriginalType);
+  if (!VTy)
+    return nullptr;
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+  unsigned ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+
+  if (OriginalSize <= ConvertScalarSize)
+    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+                         ElementCount::getFixed(ConvertEltCount));
+}
+
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                          BasicBlock::iterator &InsertPt) {
+  VectorType *VTy = cast<VectorType>(V->getType());
+  Type *NewTy = calculateConvertType(V->getType());
+  if (!NewTy)
+    return nullptr;
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+  IRBuilder<> Builder(V->getParent(), InsertPt);
+  // If there is a bitsize match, we can fit the old vector into a new vector of
+  // desired type.
+  if (OriginalSize == NewSize)
+    return cast<Instruction>(
+        Builder.CreateBitCast(V, NewTy, V->getName() + ".bc"));
+
+  // If there is a bitsize mismatch, we must use a wider vector.
+  assert(NewSize > OriginalSize);
+  ElementCount ExpandedVecElementCount =
+      ElementCount::getFixed(NewSize / VTy->getScalarSizeInBits());
+
+  SmallVector<int, 8> ShuffleMask;
+  for (unsigned I = 0; I < VTy->getElementCount().getFixedValue(); I++)
+    ShuffleMask.push_back(I);
+
+  for (uint64_t I = VTy->getElementCount().getFixedValue();
+       I < ExpandedVecElementCount.getFixedValue(); I++)
+    ShuffleMask.push_back(VTy->getElementCount().getFixedValue());
+
+  Instruction *ExpandedVec =
+      cast<Instruction>(Builder.CreateShuffleVector(V, ShuffleMask));
+  return cast<Instruction>(
+      Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc"));
+}
+
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                            BasicBlock::iterator &InsertPt,
+                                            BasicBlock *InsertBB) {
+  VectorType *NewVTy = cast<VectorType>(ConvertType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+  IRBuilder<> Builder(InsertBB, InsertPt);
+  // If there is a bitsize match, we simply convert back to the original type.
+  if (OriginalSize == NewSize) {
+    return cast<Instruction>(
+        Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc"));
+  }
+
+  // If there is a bitsize mismatch, then we must have used a wider value to
+  // hold the bits.
+  assert(OriginalSize > NewSize);
+  // For wide scalars, we can just truncate the value.
+  if (!V->getType()->isVectorTy()) {
+    Instruction *Trunc = cast<Instruction>(
+        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+  }
+
+  // For wider vectors, we must strip the MSBs to convert back to the original
+  // type.
+  ElementCount ExpandedVecElementCount =
+      ElementCount::getFixed(OriginalSize / NewVTy->getScalarSizeInBits());
+  VectorType *ExpandedVT = VectorType::get(
+      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      ExpandedVecElementCount);
+  Instruction *Converted =
+      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
+  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+  return cast<Instruction>(Builder.CreateShuffleVector(Converted, ShuffleMask));
+}
+
+bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
+  SmallVector<Instruction *, 4> Worklist;
+  SmallPtrSet<PHINode *, 4> PhiNodes;
+  SmallPtrSet<Instruction *, 4> Defs;
+  SmallPtrSet<Instruction *, 4> Uses;
+
+  Worklist.push_back(cast<Instruction>(I));
+  while (!Worklist.empty()) {
+    Instruction *II = Worklist.pop_back_val();
+    if (Visited.count(II))
+      continue;
+    Visited.insert(II);
+
+    Type *ITy = II->getType();
+    // Only vectors of illegal type will be scalarized when building the
+    // selection DAG.
+    bool ShouldReplace = ITy->isVectorTy() && ITy->getScalarSizeInBits() < 16 &&
+                         !ITy->getScalarType()->isPointerTy();
+
+    if (!ShouldReplace)
+      continue;
+
+    if (auto *Phi = dyn_cast<PHINode>(II)) {
+      PhiNodes.insert(Phi);
+      // Collect all the incoming values of problematic PHI nodes.
+      for (Value *V : Phi->incoming_values()) {
+        // Repeat the collection process for newly found PHI nodes.
+        if (auto *OpPhi = dyn_cast<PHINode>(V)) {
+          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+            Worklist.push_back(OpPhi);
+          continue;
+        }
+
+        auto IncInst = dyn_cast<Instruction>(V);
----------------
arsenm wrote:

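Please use `auto *` here, i.e. `auto *IncInst = dyn_cast<Instruction>(V);` — `dyn_cast` returns a pointer, so the declaration should spell that out.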

https://github.com/llvm/llvm-project/pull/66838


More information about the llvm-commits mailing list