[llvm] [AMDGPU] Add IR LiveReg type-based optimization (PR #66838)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 21 14:44:59 PDT 2024
================
@@ -102,14 +169,245 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
+ // "Optimize" the virtual regs that cross basic block boundaries. When
+ // building the SelectionDAG, vectors of illegal types that cross basic blocks
+ // will be scalarized and widened, with each scalar living in its
+ // own register. To work around this, the optimization converts the
+ // vectors to equivalent vectors of legal type (which are converted back
+ // before uses in subsequent blocks), to pack the bits into fewer physical
+ // registers (used in CopyToReg/CopyFromReg pairs).
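+ // Illustrative example (assumes a 32-bit convert scalar type): a <4 x i8>
+ // live across a block boundary is carried as a single i32, and a <6 x i16>
+ // as a <3 x i32>, then converted back at the first non-PHI insertion point
+ // of each using block.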
+ LiveRegOptimizer LRO(Mod, &ST);
+
bool Changed = false;
+
for (auto &BB : F)
- for (Instruction &I : llvm::make_early_inc_range(BB))
+ for (Instruction &I : make_early_inc_range(BB)) {
Changed |= visit(I);
+ Changed |= LRO.optimizeLiveType(&I);
+ }
+ LRO.removeDeadInstrs();
return Changed;
}
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+ assert(OriginalType->getScalarSizeInBits() <=
+ ConvertToScalar->getScalarSizeInBits());
+
+ FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
+
+ TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+ TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+ unsigned ConvertEltCount =
+ (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
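+ // For example (illustrative, assuming a 32-bit convert scalar): a <3 x i8>
+ // (24 bits) maps to a single i32, while a <5 x i16> (80 bits) maps to
+ // <3 x i32>.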
+
+ if (OriginalSize <= ConvertScalarSize)
+ return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+ return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+ ConvertEltCount, false);
+}
+
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+ BasicBlock::iterator &InsertPt) {
+ FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
+ Type *NewTy = calculateConvertType(V->getType());
+
+ TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+ TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+ IRBuilder<> Builder(V->getParent(), InsertPt);
+ // If there is a bitsize match, we can fit the old vector into a new vector
+ // of the desired type.
+ if (OriginalSize == NewSize)
+ return cast<Instruction>(
+ Builder.CreateBitCast(V, NewTy, V->getName() + ".bc"));
+
+ // If there is a bitsize mismatch, we must use a wider vector.
+ assert(NewSize > OriginalSize);
+ uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
+
+ SmallVector<int, 8> ShuffleMask;
+ for (unsigned I = 0; I < VTy->getElementCount().getFixedValue(); I++)
+ ShuffleMask.push_back(I);
+
+ for (uint64_t I = VTy->getElementCount().getFixedValue();
+ I < ExpandedVecElementCount; I++)
+ ShuffleMask.push_back(VTy->getElementCount().getFixedValue());
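+ // For example (illustrative): widening <3 x i16> (48 bits) to a 64-bit
+ // convert type yields the mask <0, 1, 2, 3>, where index 3 selects a lane
+ // of the shuffle's implicit poison operand, so the padding lane is
+ // undefined.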
+
+ Instruction *ExpandedVec =
+ cast<Instruction>(Builder.CreateShuffleVector(V, ShuffleMask));
+ return cast<Instruction>(
+ Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc"));
+}
+
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+ BasicBlock::iterator &InsertPt,
+ BasicBlock *InsertBB) {
+ FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
+
+ TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+ TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+ IRBuilder<> Builder(InsertBB, InsertPt);
+ // If there is a bitsize match, we simply convert back to the original type.
+ if (OriginalSize == NewSize)
+ return cast<Instruction>(
+ Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc"));
+
+ // If there is a bitsize mismatch, then we must have used a wider value to
+ // hold the bits.
+ assert(OriginalSize > NewSize);
+ // For wide scalars, we can just truncate the value.
+ if (!V->getType()->isVectorTy()) {
+ Instruction *Trunc = cast<Instruction>(
+ Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+ return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+ }
+
+ // For wider vectors, we must strip the MSBs to convert back to the original
+ // type.
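+ // For example (illustrative): a <3 x i16> that was carried as <2 x i32> is
+ // bitcast to <4 x i16> and then shuffled with mask <0, 1, 2> to drop the
+ // padding lane.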
+ VectorType *ExpandedVT = VectorType::get(
+ Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+ (OriginalSize / NewVTy->getScalarSizeInBits()), false);
+ Instruction *Converted =
+ cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
+ unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+ SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+ std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+ return cast<Instruction>(Builder.CreateShuffleVector(Converted, ShuffleMask));
+}
+
+bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
+ SmallVector<Instruction *, 4> Worklist;
+ SmallPtrSet<PHINode *, 4> PhiNodes;
+ SmallPtrSet<Instruction *, 4> Defs;
+ SmallPtrSet<Instruction *, 4> Uses;
+
+ Worklist.push_back(cast<Instruction>(I));
+ while (!Worklist.empty()) {
+ Instruction *II = Worklist.pop_back_val();
+
+ if (!Visited.insert(II).second)
+ continue;
+
+ if (!shouldReplace(II->getType()))
+ continue;
+
+ if (PHINode *Phi = dyn_cast<PHINode>(II)) {
+ PhiNodes.insert(Phi);
+ // Collect all the incoming values of problematic PHI nodes.
+ for (Value *V : Phi->incoming_values()) {
+ // Repeat the collection process for newly found PHI nodes.
+ if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+ if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+ Worklist.push_back(OpPhi);
+ continue;
+ }
+
+ Instruction *IncInst = dyn_cast<Instruction>(V);
+ // Other incoming value types (e.g. vector literals) are unhandled.
+ if (!IncInst && !isa<ConstantAggregateZero>(V))
+ return false;
+
+ // Collect all other incoming values for coercion.
+ if (IncInst)
+ Defs.insert(IncInst);
+ }
+ }
+
+ // Collect all relevant uses.
+ for (User *V : II->users()) {
+ // Repeat the collection process for problematic PHI nodes.
+ if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+ if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+ Worklist.push_back(OpPhi);
+ continue;
+ }
+
+ Instruction *UseInst = cast<Instruction>(V);
+ // Collect all uses of PHI nodes and any use that crosses BB boundaries.
+ if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
+ Uses.insert(UseInst);
+ if (!Defs.count(II) && !isa<PHINode>(II)) {
+ Defs.insert(II);
+ }
+ }
+ }
+ }
+
+ // Coerce and track the defs.
+ for (Instruction *D : Defs) {
+ if (!ValMap.contains(D)) {
+ BasicBlock::iterator InsertPt = std::next(D->getIterator());
+ Value *ConvertVal = convertToOptType(D, InsertPt);
+ assert(ConvertVal);
+ ValMap[D] = ConvertVal;
+ }
+ }
+
+ // Construct new-typed PHI nodes.
+ for (PHINode *Phi : PhiNodes)
+ ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
+ Phi->getNumIncomingValues(),
+ Phi->getName() + ".tc", Phi->getIterator());
+
+ // Connect all the PHI nodes with their new incoming values.
+ for (PHINode *Phi : PhiNodes) {
+ PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
+ bool MissingIncVal = false;
+ for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
+ Value *IncVal = Phi->getIncomingValue(I);
+ if (isa<ConstantAggregateZero>(IncVal)) {
+ Type *NewType = calculateConvertType(Phi->getType());
+ NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
+ Phi->getIncomingBlock(I));
+ } else if (ValMap.contains(IncVal))
+ NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
+ else
+ MissingIncVal = true;
+ }
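+ // If every incoming value was coerced, the original PHI is now dead;
+ // otherwise the replacement could not be completed, so drop the partially
+ // built PHI instead.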
+ if (!MissingIncVal)
+ DeadInstrs.insert(Phi);
+ else
+ DeadInstrs.insert(cast<Instruction>(ValMap[Phi]));
+ Visited.insert(NewPhi);
+ }
+ // Coerce back to the original type and replace the uses.
+ for (Instruction *U : Uses) {
+ // Replace all converted operands for a use.
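+ // Reconversions are cached per basic block (BBUseValMap), so multiple uses
+ // of the same converted value within one block share a single conversion
+ // sequence.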
+ for (auto [OpIdx, Op] : enumerate(U->operands())) {
+ if (ValMap.contains(Op)) {
+ Value *NewVal = nullptr;
+ if (BBUseValMap.contains(U->getParent()) &&
+ BBUseValMap[U->getParent()].contains(ValMap[Op]))
+ NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
+ else {
+ BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
+ NewVal =
+ convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
+ InsertPt, U->getParent());
+ BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
+ }
+ assert(NewVal);
+ U->setOperand(OpIdx, NewVal);
----------------
arsenm wrote:
Does this handle the phi with repeated successor case correctly?
https://github.com/llvm/llvm-project/pull/66838
More information about the llvm-commits mailing list