[llvm] [AMDGPU] Add IR LiveReg type-based optimization (PR #66838)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 8 00:10:29 PST 2023
================
@@ -369,9 +450,269 @@ bool AMDGPUCodeGenPrepareImpl::run(Function &F) {
}
}
}
+
+ // GlobalISel should directly use the values, and does not need to emit
+ // CopyTo/CopyFrom Regs across blocks
+ if (UsesGlobalISel)
+ return MadeChange;
+
+ // "Optimize" the virtual regs that cross basic block boundaries. In such
+ // cases, vectors of illegal types will be scalarized and widened, with each
+ // scalar living in its own physical register. The optimization converts the
+ // vectors to equivalent vectors of legal type (which are converted back
+ // before uses in subsequent blocks), to pack the bits into fewer physical
+ // registers (used in CopyToReg/CopyFromReg pairs).
+ LiveRegOptimizer LRO(Mod);
+ for (auto &BB : F) {
+ for (auto &I : BB) {
+ if (!LRO.shouldReplaceUses(I))
+ continue;
+ MadeChange |= LRO.replaceUses(I);
+ }
+ }
+
+ MadeChange |= LRO.replacePHIs();
+ return MadeChange;
+}
+
+bool LiveRegOptimizer::replaceUses(Instruction &I) {
+ bool MadeChange = false;
+
+ struct ConvertUseInfo {
+ Instruction *Converted;
+ SmallVector<Instruction *, 4> Users;
+ };
+ DenseMap<BasicBlock *, ConvertUseInfo> UseConvertTracker;
+
+ LiveRegConversion FromLRC(
+ &I, I.getParent(),
+ static_cast<BasicBlock::iterator>(std::next(I.getIterator())));
+ FromLRC.setNewType(getCompatibleType(FromLRC.getLiveRegDef()));
+ for (auto IUser = I.user_begin(); IUser != I.user_end(); IUser++) {
+
+ if (auto UserInst = dyn_cast<Instruction>(*IUser)) {
+ if (UserInst->getParent() != I.getParent()) {
+ LLVM_DEBUG(dbgs() << *UserInst << "\n\tUses "
+ << *FromLRC.getOriginalType()
+ << " from previous block. Needs conversion\n");
+ convertToOptType(FromLRC);
+ if (!FromLRC.hasConverted())
+ continue;
+ // If it is a PHI node, just create and collect the new operand. We can
+ // only replace the PHI node once we have converted all the operands
+ if (auto PhiInst = dyn_cast<PHINode>(UserInst)) {
+ for (unsigned Idx = 0; Idx < PhiInst->getNumIncomingValues(); Idx++) {
+ auto IncVal = PhiInst->getIncomingValue(Idx);
+ if (&I == dyn_cast<Instruction>(IncVal)) {
+ auto IncBlock = PhiInst->getIncomingBlock(Idx);
+ auto PHIOps = find_if(
+ PHIUpdater,
+ [&UserInst](
+ std::pair<Instruction *,
+ SmallVector<
+ std::pair<Instruction *, BasicBlock *>, 4>>
+ &Entry) { return Entry.first == UserInst; });
+
+ if (PHIOps == PHIUpdater.end())
+ PHIUpdater.push_back(
+ {UserInst, {{*FromLRC.getConverted(), IncBlock}}});
+ else
+ PHIOps->second.push_back({*FromLRC.getConverted(), IncBlock});
+
+ break;
+ }
+ }
+ continue;
+ }
+
+ // Do not create multiple conversion sequences if there are multiple
+ // uses in the same block
+ if (UseConvertTracker.contains(UserInst->getParent())) {
+ UseConvertTracker[UserInst->getParent()].Users.push_back(UserInst);
+ LLVM_DEBUG(dbgs() << "\tUser already has access to converted def\n");
+ continue;
+ }
+
+ LiveRegConversion ToLRC(*FromLRC.getConverted(), I.getType(),
+ UserInst->getParent(),
+ static_cast<BasicBlock::iterator>(
+ UserInst->getParent()->getFirstNonPHIIt()));
+ convertFromOptType(ToLRC);
+ assert(ToLRC.hasConverted());
+ UseConvertTracker[UserInst->getParent()] = {*ToLRC.getConverted(),
+ {UserInst}};
+ }
+ }
+ }
+
+ // Replace uses of the original def with the converted value in a separate
+ // loop that is not dependent upon the state of the uses
+ for (auto &Entry : UseConvertTracker) {
+ for (auto &UserInst : Entry.second.Users) {
+ LLVM_DEBUG(dbgs() << *UserInst
+ << "\n\tNow uses: " << *Entry.second.Converted << "\n");
+ UserInst->replaceUsesOfWith(&I, Entry.second.Converted);
+ MadeChange = true;
+ }
+ }
+ return MadeChange;
+}
+
+bool LiveRegOptimizer::replacePHIs() {
+ bool MadeChange = false;
+ for (auto Ele : PHIUpdater) {
+ auto ThePHINode = dyn_cast<PHINode>(Ele.first);
+ assert(ThePHINode);
+ auto NewPHINodeOps = Ele.second;
+ LLVM_DEBUG(dbgs() << "Attempting to replace: " << *ThePHINode << "\n");
+ // If we have converted all the required operands, then do the replacement
+ if (ThePHINode->getNumIncomingValues() == NewPHINodeOps.size()) {
+ IRBuilder<> Builder(Ele.first);
+ auto NPHI = Builder.CreatePHI(NewPHINodeOps[0].first->getType(),
+ NewPHINodeOps.size());
+ for (auto IncVals : NewPHINodeOps) {
+ NPHI->addIncoming(IncVals.first, IncVals.second);
+ LLVM_DEBUG(dbgs() << " Using: " << *IncVals.first
+ << " For: " << IncVals.second->getName() << "\n");
+ }
+ LLVM_DEBUG(dbgs() << "Sucessfully replaced with " << *NPHI << "\n");
+ LiveRegConversion ToLRC(NPHI, ThePHINode->getType(),
+ ThePHINode->getParent(),
+ static_cast<BasicBlock::iterator>(
+ ThePHINode->getParent()->getFirstNonPHIIt()));
+ convertFromOptType(ToLRC);
+ assert(ToLRC.hasConverted());
+ Ele.first->replaceAllUsesWith(*ToLRC.getConverted());
+ // The old PHI is no longer used
+ ThePHINode->eraseFromParent();
+ MadeChange = true;
+ }
+ }
return MadeChange;
}
+Type *LiveRegOptimizer::getCompatibleType(Instruction *InstToConvert) {
+ auto OriginalType = InstToConvert->getType();
+ assert(OriginalType->getScalarSizeInBits() <=
+ ConvertToScalar->getScalarSizeInBits());
+ auto VTy = dyn_cast<VectorType>(OriginalType);
+ if (!VTy)
+ return ConvertToScalar;
+
+ auto OriginalSize =
+ VTy->getScalarSizeInBits() * VTy->getElementCount().getFixedValue();
+ auto ConvertScalarSize = ConvertToScalar->getScalarSizeInBits();
+ auto ConvertEltCount =
+ (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+
+ return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+ llvm::ElementCount::getFixed(ConvertEltCount));
+}
+
+void LiveRegOptimizer::convertToOptType(LiveRegConversion &LR) {
+ if (LR.hasConverted()) {
+ LLVM_DEBUG(dbgs() << "\tAlready has converted def\n");
+ return;
+ }
+
+ auto VTy = dyn_cast<VectorType>(LR.getOriginalType());
+ assert(VTy);
+ auto NewVTy = dyn_cast<VectorType>(LR.getNewType());
+ assert(NewVTy);
+
+ auto V = static_cast<Value *>(LR.getLiveRegDef());
+ auto OriginalSize =
+ VTy->getScalarSizeInBits() * VTy->getElementCount().getFixedValue();
+ auto NewSize =
+ NewVTy->getScalarSizeInBits() * NewVTy->getElementCount().getFixedValue();
+
+ auto &Builder = LR.getConverBuilder();
+
+ // If there is a bitsize match, we can fit the old vector into a new vector of
+ // desired type
+ if (OriginalSize == NewSize) {
+ LR.setConverted(dyn_cast<Instruction>(Builder.CreateBitCast(V, NewVTy)));
+ LLVM_DEBUG(dbgs() << "\tConverted def to "
+ << *(*LR.getConverted())->getType() << "\n");
+ return;
+ }
+
+ // If there is a bitsize mismatch, we must use a wider vector
+ assert(NewSize > OriginalSize);
+ auto ExpandedVecElementCount =
+ llvm::ElementCount::getFixed(NewSize / VTy->getScalarSizeInBits());
+
+ SmallVector<int, 8> ShuffleMask;
+ for (unsigned I = 0; I < VTy->getElementCount().getFixedValue(); I++)
+ ShuffleMask.push_back(I);
+
+ for (uint64_t I = VTy->getElementCount().getFixedValue();
+ I < ExpandedVecElementCount.getFixedValue(); I++)
+ ShuffleMask.push_back(VTy->getElementCount().getFixedValue());
+
+ auto ExpandedVec =
+ dyn_cast<Instruction>(Builder.CreateShuffleVector(V, ShuffleMask));
+ LR.setConverted(
+ dyn_cast<Instruction>(Builder.CreateBitCast(ExpandedVec, NewVTy)));
+ LLVM_DEBUG(dbgs() << "\tConverted def to " << *(*LR.getConverted())->getType()
+ << "\n");
+ return;
+}
+
+void LiveRegOptimizer::convertFromOptType(LiveRegConversion &LRC) {
+ auto VTy = dyn_cast<VectorType>(LRC.getOriginalType());
+ assert(VTy);
+ auto NewVTy = dyn_cast<VectorType>(LRC.getNewType());
+ assert(NewVTy);
+
+ auto V = static_cast<Value *>(LRC.getLiveRegDef());
+ auto OriginalSize =
+ VTy->getScalarSizeInBits() * VTy->getElementCount().getFixedValue();
+ auto NewSize =
+ NewVTy->getScalarSizeInBits() * NewVTy->getElementCount().getFixedValue();
+
+ auto &Builder = LRC.getConverBuilder();
+
+ // If there is a bitsize match, we simply convert back to the original type
+ if (OriginalSize == NewSize) {
+ LRC.setConverted(dyn_cast<Instruction>(Builder.CreateBitCast(V, NewVTy)));
+ LLVM_DEBUG(dbgs() << "\tProduced for user: " << **LRC.getConverted()
+ << "\n");
+ return;
+ }
+
+ // If there is a bitsize mismatch, we have used a wider vector and must strip
+ // the MSBs to convert back to the original type
+ assert(OriginalSize > NewSize);
+ auto ExpandedVecElementCount = llvm::ElementCount::getFixed(
+ OriginalSize / NewVTy->getScalarSizeInBits());
+ auto ExpandedVT = VectorType::get(
+ Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+ ExpandedVecElementCount);
+ auto Converted = dyn_cast<Instruction>(
+ Builder.CreateBitCast(LRC.getLiveRegDef(), ExpandedVT));
+
+ auto NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+ SmallVector<int, 8> ShuffleMask;
+ for (uint64_t I = 0; I < NarrowElementCount; I++)
----------------
arsenm wrote:
can do this with the constructor
https://github.com/llvm/llvm-project/pull/66838
More information about the llvm-commits
mailing list