[llvm] [AMDGPU] Add IR LiveReg type-based optimization (PR #66838)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 24 23:56:11 PDT 2024
================
@@ -102,14 +182,266 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
+ // "Optimize" the virtual regs that cross basic block boundaries. In such
+ // cases, vectors of illegal types will be scalarized and widened, with each
+ // scalar living in its own physical register. The optimization converts the
+ // vectors to equivalent vectors of legal type (which are convereted back
+ // before uses in subsequent blocks), to pack the bits into fewer physical
+ // registers (used in CopyToReg/CopyFromReg pairs).
+ LiveRegOptimizer LRO(Mod);
+
bool Changed = false;
for (auto &BB : F)
- for (Instruction &I : llvm::make_early_inc_range(BB))
+ for (Instruction &I : llvm::make_early_inc_range(BB)) {
Changed |= visit(I);
+ if (!LRO.shouldReplaceUses(I))
+ continue;
+ Changed |= LRO.replaceUses(I);
+ }
+ Changed |= LRO.replacePHIs();
return Changed;
}
+// Route every use of \p I that is live across a block boundary through the
+// packed (legal) type: I is converted once after its definition, each using
+// block gets a single conversion back to I's original type, and PHI operands
+// are only recorded in PHIUpdater so replacePHIs() can rebuild the PHI later.
+// Returns true if any non-PHI use was rewritten.
+bool LiveRegOptimizer::replaceUses(Instruction &I) {
+  bool MadeChange = false;
+
+  // Per using block: the conversion back to I's original type, and the uses of
+  // I in that block that must be rewritten to read it.
+  struct ConvertUseInfo {
+    Instruction *Converted = nullptr;
+    SmallVector<Use *, 4> Uses;
+  };
+  DenseMap<BasicBlock *, ConvertUseInfo> InsertedConversionMap;
+
+  BasicBlock *DefBB = I.getParent();
+  // The conversion to the packed type goes right after the def. A PHI def is
+  // the exception: only PHIs may appear in a block's PHI group, so place the
+  // conversion at the block's first legal insertion point instead.
+  BasicBlock::iterator ConvertPt =
+      isa<PHINode>(I)
+          ? DefBB->getFirstInsertionPt()
+          : static_cast<BasicBlock::iterator>(std::next(I.getIterator()));
+  ConversionCandidateInfo FromCCI(&I, DefBB, ConvertPt);
+  FromCCI.setNewType(getCompatibleType(FromCCI.getLiveRegDef()));
+
+  // Snapshot the uses before creating any conversion: the conversion itself
+  // becomes a new use of I and must not be visited or rewritten here. Walking
+  // uses (not users) also gives every PHI operand its own entry, so a PHI that
+  // receives I along several edges gets one (value, block) pair per edge.
+  SmallVector<Use *, 8> OrigUses;
+  for (Use &U : I.uses())
+    OrigUses.push_back(&U);
+
+  for (Use *U : OrigUses) {
+    auto *UserInst = dyn_cast<Instruction>(U->getUser());
+    if (!UserInst)
+      continue;
+    BasicBlock *UseBB = UserInst->getParent();
+    auto *PhiInst = dyn_cast<PHINode>(UserInst);
+    // Uses inside the defining block need no conversion, except PHI operands,
+    // which flow across an incoming edge.
+    if (UseBB == DefBB && !PhiInst)
+      continue;
+
+    LLVM_DEBUG(dbgs() << *UserInst << "\n\tUses "
+                      << *FromCCI.getOriginalType()
+                      << " from previous block. Needs conversion\n");
+    convertToOptType(FromCCI);
+    if (!FromCCI.hasConverted())
+      continue;
+
+    // A PHI can only be replaced once all of its operands are converted, so
+    // just record the converted operand for this use's incoming edge.
+    if (PhiInst) {
+      BasicBlock *IncBlock = PhiInst->getIncomingBlock(*U);
+      auto PHIOps = find_if(PHIUpdater, [UserInst](const auto &Entry) {
+        return Entry.first == UserInst;
+      });
+      if (PHIOps == PHIUpdater.end())
+        PHIUpdater.push_back({UserInst, {{FromCCI.getConverted(), IncBlock}}});
+      else
+        PHIOps->second.push_back({FromCCI.getConverted(), IncBlock});
+      continue;
+    }
+
+    // Do not create multiple conversion sequences if there are multiple uses
+    // in the same block.
+    auto It = InsertedConversionMap.find(UseBB);
+    if (It != InsertedConversionMap.end()) {
+      It->second.Uses.push_back(U);
+      LLVM_DEBUG(dbgs() << "\tUser already has access to converted def\n");
+      continue;
+    }
+
+    ConversionCandidateInfo ToCCI(
+        FromCCI.getConverted(), I.getType(), UseBB,
+        static_cast<BasicBlock::iterator>(UseBB->getFirstNonPHIIt()));
+    convertFromOptType(ToCCI);
+    assert(ToCCI.hasConverted());
+    ConvertUseInfo &Info = InsertedConversionMap[UseBB];
+    Info.Converted = ToCCI.getConverted();
+    Info.Uses.push_back(U);
+  }
+
+  // Rewrite the recorded uses in a separate loop so the walk above never
+  // observes a partially rewritten use list.
+  for (auto &Entry : InsertedConversionMap) {
+    for (Use *U : Entry.second.Uses) {
+      LLVM_DEBUG(dbgs() << *U->getUser()
+                        << "\n\tNow uses: " << *Entry.second.Converted << "\n");
+      U->set(Entry.second.Converted);
+      MadeChange = true;
+    }
+  }
+  return MadeChange;
+}
+
+// Rebuild every PHI recorded in PHIUpdater whose incoming values have all been
+// converted: a new PHI of the packed type is created in front of it, converted
+// back to the original type after the block's PHI group, and the original PHI
+// is replaced and erased. PHIs with unconverted operands are left untouched.
+// Returns true if any PHI was replaced.
+bool LiveRegOptimizer::replacePHIs() {
+  bool MadeChange = false;
+  // Iterate by reference: each entry owns a SmallVector of operands that a
+  // by-value loop would copy for every PHI.
+  for (const auto &Ele : PHIUpdater) {
+    auto *ThePHINode = cast<PHINode>(Ele.first);
+    const auto &NewPHINodeOps = Ele.second;
+    LLVM_DEBUG(dbgs() << "Attempting to replace: " << *ThePHINode << "\n");
+    // Only do the replacement once all the required operands are converted.
+    if (ThePHINode->getNumIncomingValues() != NewPHINodeOps.size())
+      continue;
+
+    IRBuilder<> Builder(ThePHINode);
+    PHINode *NPHI = Builder.CreatePHI(NewPHINodeOps[0].first->getType(),
+                                      NewPHINodeOps.size());
+    for (const auto &[IncVal, IncBlock] : NewPHINodeOps) {
+      NPHI->addIncoming(IncVal, IncBlock);
+      LLVM_DEBUG(dbgs() << " Using: " << *IncVal
+                        << " For: " << IncBlock->getName() << "\n");
+    }
+    LLVM_DEBUG(dbgs() << "Sucessfully replaced with " << *NPHI << "\n");
+    // Existing users still expect the original type; convert back right after
+    // the PHI group.
+    ConversionCandidateInfo ToCCI(
+        NPHI, ThePHINode->getType(), ThePHINode->getParent(),
+        static_cast<BasicBlock::iterator>(
+            ThePHINode->getParent()->getFirstNonPHIIt()));
+    convertFromOptType(ToCCI);
+    assert(ToCCI.hasConverted());
+    ThePHINode->replaceAllUsesWith(ToCCI.getConverted());
+    // The old PHI is no longer used.
+    ThePHINode->eraseFromParent();
+    MadeChange = true;
+  }
+  return MadeChange;
+}
+
+// Return the type used to carry \p InstToConvert's value across blocks:
+//  - non-vector values use ConvertToScalar directly;
+//  - a vector whose total size fits in one ConvertToScalar-sized scalar is
+//    packed into a single integer of that width;
+//  - larger vectors become a fixed vector of such integers, rounded up so that
+//    every bit of the original vector fits.
+Type *LiveRegOptimizer::getCompatibleType(Instruction *InstToConvert) {
+  Type *OriginalType = InstToConvert->getType();
+  unsigned ConvertScalarSize = ConvertToScalar->getScalarSizeInBits();
+  assert(OriginalType->getScalarSizeInBits() <= ConvertScalarSize);
+  auto *VTy = dyn_cast<VectorType>(OriginalType);
+  if (!VTy)
+    return ConvertToScalar;
+
+  // Explicitly take the fixed size; the implicit TypeSize -> unsigned
+  // conversion would likewise reject scalable vectors.
+  unsigned OriginalSize = VTy->getPrimitiveSizeInBits().getFixedValue();
+  LLVMContext &Ctx = Mod->getContext();
+  if (OriginalSize <= ConvertScalarSize)
+    return Type::getIntNTy(Ctx, ConvertScalarSize);
+
+  // Ceiling division: enough elements to hold all OriginalSize bits.
+  unsigned ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+  return FixedVectorType::get(Type::getIntNTy(Ctx, ConvertScalarSize),
+                              ConvertEltCount);
+}
+
+void LiveRegOptimizer::convertToOptType(ConversionCandidateInfo &LR) {
+ if (LR.hasConverted()) {
+ LLVM_DEBUG(dbgs() << "\tAlready has converted def\n");
+ return;
+ }
+
+ VectorType *VTy = cast<VectorType>(LR.getOriginalType());
+ Type *NewTy = LR.getNewType();
+
+ unsigned OriginalSize = VTy->getPrimitiveSizeInBits();
+ unsigned NewSize = NewTy->getPrimitiveSizeInBits();
+
+ auto &Builder = LR.getConvertBuilder();
+ Value *V = static_cast<Value *>(LR.getLiveRegDef());
+ // If there is a bitsize match, we can fit the old vector into a new vector of
+ // desired type
+ if (OriginalSize == NewSize) {
+ LR.setConverted(dyn_cast<Instruction>(Builder.CreateBitCast(V, NewTy)));
+ LLVM_DEBUG(dbgs() << "\tConverted def to " << *LR.getConverted()->getType()
+ << "\n");
+ return;
+ }
+
+ // If there is a bitsize mismatch, we must use a wider vector
+ assert(NewSize > OriginalSize);
+ ElementCount ExpandedVecElementCount =
+ llvm::ElementCount::getFixed(NewSize / VTy->getScalarSizeInBits());
----------------
arsenm wrote:
Don't need llvm::
https://github.com/llvm/llvm-project/pull/66838
More information about the llvm-commits mailing list