[llvm] [AMDGPU] Add IR LiveReg type-based optimization (PR #66838)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 05:01:22 PDT 2024


================
@@ -102,14 +144,248 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
+  // "Optimize" the virtual regs that cross basic block boundaries. When
+  // building the SelectionDAG, vectors of illegal types that cross basic blocks
+  // will be scalarized and widened, with each scalar living in its
+  // own register. To work around this, this optimization converts the
+  // vectors to equivalent vectors of legal type (which are converted back
+  // before uses in subsequent blocks), to pack the bits into fewer physical
+  // registers (used in CopyToReg/CopyFromReg pairs).
+  LiveRegOptimizer LRO(Mod);
+
   bool Changed = false;
+
   for (auto &BB : F)
-    for (Instruction &I : llvm::make_early_inc_range(BB))
+    for (Instruction &I : make_early_inc_range(BB)) {
       Changed |= visit(I);
+      Changed |= LRO.optimizeLiveType(&I);
+    }
 
+  LRO.removeDeadInstrs();
   return Changed;
 }
 
+/// Returns the type used to carry a value of \p OriginalType across blocks:
+/// a single integer as wide as ConvertToScalar when the whole vector fits in
+/// one, otherwise a fixed vector of such integers covering all of its bits.
+/// Returns nullptr for anything that is not a fixed-width vector.
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+  // Only fixed-width vectors have a packed form. Check this before the size
+  // assertion so that other types (e.g. a wide scalar) are rejected with
+  // nullptr rather than tripping the assert.
+  auto *VTy = dyn_cast<FixedVectorType>(OriginalType);
+  if (!VTy)
+    return nullptr;
+
+  assert(VTy->getScalarSizeInBits() <=
+         ConvertToScalar->getScalarSizeInBits());
+
+  uint64_t OriginalSize = DL->getTypeSizeInBits(VTy).getFixedValue();
+  uint64_t ConvertScalarSize =
+      DL->getTypeSizeInBits(ConvertToScalar).getFixedValue();
+  LLVMContext &Ctx = Mod->getContext();
+
+  // The whole vector fits into a single packed scalar.
+  if (OriginalSize <= ConvertScalarSize)
+    return IntegerType::get(Ctx, ConvertScalarSize);
+
+  // Otherwise use the fewest packed scalars that cover every bit.
+  uint64_t ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+  return FixedVectorType::get(Type::getIntNTy(Ctx, ConvertScalarSize),
+                              ConvertEltCount);
+}
+
+/// Repacks the fixed vector \p V into the type chosen by calculateConvertType,
+/// emitting the conversion at \p InsertPt in V's parent block. Returns nullptr
+/// when V has no packed form.
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                          BasicBlock::iterator &InsertPt) {
+  // Only fixed vectors can be packed; check before casting so that other
+  // types yield nullptr instead of a failed cast.
+  auto *VTy = dyn_cast<FixedVectorType>(V->getType());
+  if (!VTy)
+    return nullptr;
+  Type *NewTy = calculateConvertType(VTy);
+  if (!NewTy)
+    return nullptr;
+
+  uint64_t OriginalSize = DL->getTypeSizeInBits(VTy).getFixedValue();
+  uint64_t NewSize = DL->getTypeSizeInBits(NewTy).getFixedValue();
+
+  IRBuilder<> Builder(V->getParent(), InsertPt);
+  // Same bit width: the old vector reinterprets directly as the new type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
+
+  // Bit width mismatch: the packed type is wider, so the value is padded.
+  assert(NewSize > OriginalSize);
+  unsigned EltBits = VTy->getScalarSizeInBits();
+  unsigned NumElts = VTy->getNumElements();
+
+  // When the padding is not a whole number of elements (e.g. <3 x i5> packed
+  // into i32), no widening shuffle yields exactly NewSize bits and the final
+  // bitcast would be invalid. Widen through integers instead.
+  if (NewSize % EltBits != 0) {
+    Value *Flat = Builder.CreateBitCast(V, Builder.getIntNTy(OriginalSize));
+    Value *Wide = Builder.CreateZExt(Flat, Builder.getIntNTy(NewSize));
+    return Builder.CreateBitCast(Wide, NewTy, V->getName() + ".bc");
+  }
+
+  // Widen with a shuffle: lanes [0, NumElts) keep the original elements and
+  // every extra lane selects from the unused second shuffle operand.
+  SmallVector<int, 8> ShuffleMask(NewSize / EltBits, static_cast<int>(NumElts));
+  std::iota(ShuffleMask.begin(), ShuffleMask.begin() + NumElts, 0);
+  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
+  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
+}
+
+/// Converts the packed value \p V back to \p ConvertType (a fixed vector),
+/// emitting the conversion at \p InsertPt in \p InsertBB.
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                            BasicBlock::iterator &InsertPt,
+                                            BasicBlock *InsertBB) {
+  auto *NewVTy = cast<FixedVectorType>(ConvertType);
+
+  uint64_t OriginalSize = DL->getTypeSizeInBits(V->getType()).getFixedValue();
+  uint64_t NewSize = DL->getTypeSizeInBits(NewVTy).getFixedValue();
+
+  IRBuilder<> Builder(InsertBB, InsertPt);
+  // Same bit width: simply reinterpret as the original type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
+
+  // Otherwise a wider value was used to hold the bits; the high bits are
+  // padding and must be dropped.
+  assert(OriginalSize > NewSize);
+  unsigned EltBits = NewVTy->getScalarSizeInBits();
+
+  // A wide scalar, or a wide vector whose width is not a whole number of
+  // target elements (e.g. <7 x i5> carried in <2 x i32>): flatten to an
+  // integer, truncate, and reinterpret.
+  if (!V->getType()->isVectorTy() || OriginalSize % EltBits != 0) {
+    Value *Flat = Builder.CreateBitCast(V, Builder.getIntNTy(OriginalSize));
+    Value *Trunc = Builder.CreateTrunc(Flat, Builder.getIntNTy(NewSize));
+    return Builder.CreateBitCast(Trunc, NewVTy);
+  }
+
+  // Reinterpret the wide vector as lanes of the target element type (so the
+  // shuffle result type matches ConvertType exactly) and keep only the
+  // leading lanes.
+  auto *ExpandedVT =
+      FixedVectorType::get(NewVTy->getElementType(), OriginalSize / EltBits);
+  Value *Converted = Builder.CreateBitCast(V, ExpandedVT);
+
+  SmallVector<int, 8> ShuffleMask(NewVTy->getNumElements());
+  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+  return Builder.CreateShuffleVector(Converted, ShuffleMask);
+}
+
+bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
+  SmallVector<Instruction *, 4> Worklist;
+  SmallPtrSet<PHINode *, 4> PhiNodes;
+  SmallPtrSet<Instruction *, 4> Defs;
+  SmallPtrSet<Instruction *, 4> Uses;
+
+  Worklist.push_back(cast<Instruction>(I));
+  while (!Worklist.empty()) {
+    Instruction *II = Worklist.pop_back_val();
+    if (Visited.count(II))
+      continue;
+    Visited.insert(II);
+
+    Type *ITy = II->getType();
+    // Only vectors of illegal type will be scalarized when building the
+    // selection DAG.
+    bool ShouldReplace = ITy->isVectorTy() && ITy->getScalarSizeInBits() < 16 &&
+                         !ITy->getScalarType()->isPointerTy();
----------------
arsenm wrote:

Can you move this to a predicate helper function? Also can you express this in terms of TLI legal types? Also skip scalable vectors.

Above, you also assume this is an integer type. Just in case those weird float types get added to the IR, you should verify that this is an integer.

https://github.com/llvm/llvm-project/pull/66838


More information about the llvm-commits mailing list