[llvm] [AMDGPU] Add IR LiveReg type-based optimization (PR #66838)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 14:44:59 PDT 2024


================
@@ -102,14 +169,245 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
+  // "Optimize" the virtual regs that cross basic block boundaries. When
+  // building the SelectionDAG, vectors of illegal types that cross basic blocks
+  // will be scalarized and widened, with each scalar living in its
+  // own register. To work around this, this optimization converts the
+  // vectors to equivalent vectors of legal type (which are converted back
+  // before uses in subsequent blocks), to pack the bits into fewer physical
+  // registers (used in CopyToReg/CopyFromReg pairs).
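+  //
+  // As a hypothetical sketch (names invented), assuming <4 x i8> is a
+  // flagged type, a def that is used in another block:
+  //   bb0:
+  //     %v = load <4 x i8>, ptr %p
+  //     br label %bb1
+  //   bb1:
+  //     %e = extractelement <4 x i8> %v, i64 0
+  // is rewritten so that a single i32 is live across the edge:
+  //   bb0:
+  //     %v = load <4 x i8>, ptr %p
+  //     %v.bc = bitcast <4 x i8> %v to i32
+  //     br label %bb1
+  //   bb1:
+  //     %v.orig = bitcast i32 %v.bc to <4 x i8>
+  //     %e = extractelement <4 x i8> %v.orig, i64 0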
+  LiveRegOptimizer LRO(Mod, &ST);
+
   bool Changed = false;
+
   for (auto &BB : F)
-    for (Instruction &I : llvm::make_early_inc_range(BB))
+    for (Instruction &I : make_early_inc_range(BB)) {
       Changed |= visit(I);
+      Changed |= LRO.optimizeLiveType(&I);
+    }
 
+  LRO.removeDeadInstrs();
   return Changed;
 }
 
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+  assert(OriginalType->getScalarSizeInBits() <=
+         ConvertToScalar->getScalarSizeInBits());
+
+  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+  unsigned ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
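+  // For example, with a 32-bit ConvertToScalar (hypothetical inputs):
+  //   <3 x i8>  (24 bits) -> i32        (fits in a single convert scalar)
+  //   <5 x i16> (80 bits) -> <3 x i32>  (ceil(80/32) = 3 lanes)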
+
+  if (OriginalSize <= ConvertScalarSize)
+    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+                         ConvertEltCount, false);
+}
+
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                          BasicBlock::iterator &InsertPt) {
+  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
+  Type *NewTy = calculateConvertType(V->getType());
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+  IRBuilder<> Builder(V->getParent(), InsertPt);
+  // If the bitsizes match, we can simply bitcast to the new type (which
+  // may be a scalar when the whole vector fits in one convert scalar).
+  if (OriginalSize == NewSize)
+    return cast<Instruction>(
+        Builder.CreateBitCast(V, NewTy, V->getName() + ".bc"));
+
+  // If there is a bitsize mismatch, we must use a wider vector.
+  assert(NewSize > OriginalSize);
+  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
+
+  SmallVector<int, 8> ShuffleMask;
+  for (unsigned I = 0; I < VTy->getElementCount().getFixedValue(); I++)
+    ShuffleMask.push_back(I);
+
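+  // The padding loop below uses index NumElts, which selects lane 0 of
+  // the implicit poison second operand of the one-operand shufflevector,
+  // so the extra lanes are poison. For example (hypothetical values), a
+  // <3 x i8> def is packed as:
+  //   %widened = shufflevector <3 x i8> %v, <3 x i8> poison,
+  //                            <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  //   %v.bc = bitcast <4 x i8> %widened to i32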
+  for (uint64_t I = VTy->getElementCount().getFixedValue();
+       I < ExpandedVecElementCount; I++)
+    ShuffleMask.push_back(VTy->getElementCount().getFixedValue());
+
+  Instruction *ExpandedVec =
+      cast<Instruction>(Builder.CreateShuffleVector(V, ShuffleMask));
+  return cast<Instruction>(
+      Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc"));
+}
+
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                            BasicBlock::iterator &InsertPt,
+                                            BasicBlock *InsertBB) {
+  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+  IRBuilder<> Builder(InsertBB, InsertPt);
+  // If there is a bitsize match, we simply convert back to the original type.
+  if (OriginalSize == NewSize)
+    return cast<Instruction>(
+        Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc"));
+
+  // If there is a bitsize mismatch, then we must have used a wider value to
+  // hold the bits.
+  assert(OriginalSize > NewSize);
+  // For wide scalars, we can just truncate the value.
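+  // For example (hypothetically), an i32 holding the bits of a <3 x i8>
+  // is truncated to i24 and then bitcast to <3 x i8>.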
+  if (!V->getType()->isVectorTy()) {
+    Instruction *Trunc = cast<Instruction>(
+        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+  }
+
+  // For wider vectors, we must strip the MSBs to convert back to the original
+  // type.
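+  // For example (hypothetical values), a <6 x i8> live value that was
+  // packed as <2 x i32> converts back as:
+  //   %expanded = bitcast <2 x i32> %packed to <8 x i8>
+  //   %v.orig = shufflevector <8 x i8> %expanded, <8 x i8> poison,
+  //                           <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>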
+  VectorType *ExpandedVT = VectorType::get(
+      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
+  Instruction *Converted =
+      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
+  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+  return cast<Instruction>(Builder.CreateShuffleVector(Converted, ShuffleMask));
+}
+
+bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
+  SmallVector<Instruction *, 4> Worklist;
+  SmallPtrSet<PHINode *, 4> PhiNodes;
+  SmallPtrSet<Instruction *, 4> Defs;
+  SmallPtrSet<Instruction *, 4> Uses;
+
+  Worklist.push_back(I);
+  while (!Worklist.empty()) {
+    Instruction *II = Worklist.pop_back_val();
+
+    if (!Visited.insert(II).second)
+      continue;
+
+    if (!shouldReplace(II->getType()))
+      continue;
+
+    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
+      PhiNodes.insert(Phi);
+      // Collect all the incoming values of problematic PHI nodes.
+      for (Value *V : Phi->incoming_values()) {
+        // Repeat the collection process for newly found PHI nodes.
+        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+            Worklist.push_back(OpPhi);
+          continue;
+        }
+
+        Instruction *IncInst = dyn_cast<Instruction>(V);
+        // Other incoming value types (e.g. vector literals) are unhandled.
+        if (!IncInst && !isa<ConstantAggregateZero>(V))
+          return false;
+
+        // Collect all other incoming values for coercion.
+        if (IncInst)
+          Defs.insert(IncInst);
+      }
+    }
+
+    // Collect all relevant uses.
+    for (User *V : II->users()) {
+      // Repeat the collection process for problematic PHI nodes.
+      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+          Worklist.push_back(OpPhi);
+        continue;
+      }
+
+      Instruction *UseInst = cast<Instruction>(V);
+      // Collect all uses of PHINodes and any use that crosses BB boundaries.
+      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
+        Uses.insert(UseInst);
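+        // Non-phi defs are packed immediately after their definition
+        // below; phis are instead rebuilt with the converted type.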
+        if (!Defs.count(II) && !isa<PHINode>(II)) {
+          Defs.insert(II);
+        }
+      }
+    }
+  }
+
+  // Coerce and track the defs.
+  for (Instruction *D : Defs) {
+    if (!ValMap.contains(D)) {
+      BasicBlock::iterator InsertPt = std::next(D->getIterator());
+      Value *ConvertVal = convertToOptType(D, InsertPt);
+      assert(ConvertVal);
+      ValMap[D] = ConvertVal;
+    }
+  }
+
+  // Construct new-typed PHI nodes.
+  for (PHINode *Phi : PhiNodes)
+    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
+                                  Phi->getNumIncomingValues(),
+                                  Phi->getName() + ".tc", Phi->getIterator());
+
+  // Connect all the PHI nodes with their new incoming values.
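+  // For example (hypothetical names), a phi of a flagged type:
+  //   %p = phi <4 x i8> [ %a, %bb0 ], [ zeroinitializer, %bb1 ]
+  // is rebuilt over the packed type, with zeroinitializer mapped to
+  // the zero of the new type:
+  //   %p.tc = phi i32 [ %a.bc, %bb0 ], [ 0, %bb1 ]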
+  for (PHINode *Phi : PhiNodes) {
+    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
+    bool MissingIncVal = false;
+    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
+      Value *IncVal = Phi->getIncomingValue(I);
+      if (isa<ConstantAggregateZero>(IncVal)) {
+        Type *NewType = calculateConvertType(Phi->getType());
+        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
+                            Phi->getIncomingBlock(I));
+      } else if (ValMap.contains(IncVal))
+        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
+      else
+        MissingIncVal = true;
+    }
+    if (!MissingIncVal)
+      DeadInstrs.insert(Phi);
+    else
+      DeadInstrs.insert(cast<Instruction>(ValMap[Phi]));
+    Visited.insert(NewPhi);
+  }
+  // Coerce back to the original type and replace the uses.
+  for (Instruction *U : Uses) {
+    // Replace all converted operands for a use.
+    for (auto [OpIdx, Op] : enumerate(U->operands())) {
+      if (ValMap.contains(Op)) {
+        Value *NewVal = nullptr;
+        if (BBUseValMap.contains(U->getParent()) &&
+            BBUseValMap[U->getParent()].contains(ValMap[Op]))
+          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
+        else {
+          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
+          NewVal =
+              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
+                                 InsertPt, U->getParent());
+          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
+        }
+        assert(NewVal);
+        U->setOperand(OpIdx, NewVal);
----------------
arsenm wrote:

Does this handle the phi with repeated successor case correctly? 
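For reference, a hypothetical instance of that case is a phi whose
incoming block repeats because a single predecessor has multiple edges
to it, e.g. via a switch:

    entry:
      switch i32 %x, label %exit [ i32 0, label %bb
                                   i32 1, label %bb ]
    bb:
      %p = phi <4 x i8> [ %a, %entry ], [ %a, %entry ]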

https://github.com/llvm/llvm-project/pull/66838

