[llvm] [AMDGPU] Add IR LiveReg type-based optimization (PR #66838)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 8 00:10:29 PST 2023


================
@@ -369,9 +450,269 @@ bool AMDGPUCodeGenPrepareImpl::run(Function &F) {
       }
     }
   }
+
+  // GlobalISel should directly use the values, and do not need to emit
+  // CopyTo/CopyFrom Regs across blocks
+  if (UsesGlobalISel)
+    return MadeChange;
+
+  // "Optimize" the virtual regs that cross basic block boundaries. In such
+  // cases, vectors of illegal types will be scalarized and widened, with each
+  // scalar living in its own physical register. The optimization converts the
+  // vectors to equivalent vectors of legal type (which are convereted back
+  // before uses in subsequenmt blocks), to pack the bits into fewer physical
+  // registers (used in CopyToReg/CopyFromReg pairs).
+  LiveRegOptimizer LRO(Mod);
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      if (!LRO.shouldReplaceUses(I))
+        continue;
+      MadeChange |= LRO.replaceUses(I);
+    }
+  }
+
+  MadeChange |= LRO.replacePHIs();
+  return MadeChange;
+}
+
+/// Rewrites the cross-block uses of \p I so that the value is carried between
+/// blocks in the packed type chosen by getCompatibleType().
+///
+/// The def is converted at most once, right after \p I (or at the first
+/// non-PHI position of its block when \p I is itself a PHI, since nothing may
+/// be inserted between PHIs). For every distinct block containing non-PHI
+/// users, one conversion back to the original type is emitted at that block's
+/// first non-PHI position and shared by all users there. Uses inside the
+/// defining block are left untouched. PHI users are not rewritten here: the
+/// converted value and the incoming block of each PHI use are recorded in
+/// PHIUpdater so replacePHIs() can rebuild the PHI once every incoming value
+/// has been converted.
+///
+/// \returns true if any conversion code was inserted or any use rewritten.
+bool LiveRegOptimizer::replaceUses(Instruction &I) {
+  bool MadeChange = false;
+
+  struct ConvertUseInfo {
+    Instruction *Converted;
+    SmallVector<Instruction *, 4> Users;
+  };
+  DenseMap<BasicBlock *, ConvertUseInfo> UseConvertTracker;
+
+  BasicBlock *DefBB = I.getParent();
+  auto DefInsertPt =
+      static_cast<BasicBlock::iterator>(std::next(I.getIterator()));
+  if (isa<PHINode>(I))
+    DefInsertPt = static_cast<BasicBlock::iterator>(DefBB->getFirstNonPHIIt());
+
+  LiveRegConversion FromLRC(&I, DefBB, DefInsertPt);
+  FromLRC.setNewType(getCompatibleType(FromLRC.getLiveRegDef()));
+
+  // Walk uses rather than users: a PHI that receives I along several edges
+  // has one use per edge, and each use identifies its own incoming block.
+  for (Use &U : I.uses()) {
+    auto *UserInst = dyn_cast<Instruction>(U.getUser());
+    if (!UserInst || UserInst->getParent() == DefBB)
+      continue;
+
+    LLVM_DEBUG(dbgs() << *UserInst << "\n\tUses "
+                      << *FromLRC.getOriginalType()
+                      << " from previous block. Needs conversion\n");
+    convertToOptType(FromLRC);
+    if (!FromLRC.hasConverted())
+      continue;
+
+    // For a PHI user, only record the converted operand for this use's
+    // incoming edge. The PHI itself is replaced later, once all of its
+    // operands have been converted.
+    if (auto *PhiInst = dyn_cast<PHINode>(UserInst)) {
+      BasicBlock *IncBlock = PhiInst->getIncomingBlock(U);
+      Instruction *ConvertedDef = *FromLRC.getConverted();
+      auto PHIOps = find_if(PHIUpdater, [UserInst](const auto &Entry) {
+        return Entry.first == UserInst;
+      });
+      if (PHIOps == PHIUpdater.end())
+        PHIUpdater.push_back({UserInst, {{ConvertedDef, IncBlock}}});
+      else
+        PHIOps->second.push_back({ConvertedDef, IncBlock});
+      continue;
+    }
+
+    // Do not create multiple conversion sequences if there are multiple
+    // uses in the same block.
+    BasicBlock *UseBB = UserInst->getParent();
+    auto Tracked = UseConvertTracker.find(UseBB);
+    if (Tracked != UseConvertTracker.end()) {
+      Tracked->second.Users.push_back(UserInst);
+      LLVM_DEBUG(dbgs() << "\tUser already has access to converted def\n");
+      continue;
+    }
+
+    LiveRegConversion ToLRC(*FromLRC.getConverted(), I.getType(), UseBB,
+                            static_cast<BasicBlock::iterator>(
+                                UseBB->getFirstNonPHIIt()));
+    convertFromOptType(ToLRC);
+    assert(ToLRC.hasConverted());
+    UseConvertTracker[UseBB] = {*ToLRC.getConverted(), {UserInst}};
+  }
+
+  // Inserted conversion code is an IR change even if no use is rewritten
+  // below (e.g. every cross-block user is a PHI that replacePHIs() may end up
+  // leaving untouched).
+  if (FromLRC.hasConverted())
+    MadeChange = true;
+
+  // Rewrite the operands in a separate loop so that the use list of I is not
+  // modified while it is being walked above.
+  for (auto &Entry : UseConvertTracker) {
+    for (Instruction *UserInst : Entry.second.Users) {
+      LLVM_DEBUG(dbgs() << *UserInst << "\n\tNow uses: "
+                        << *Entry.second.Converted << "\n");
+      UserInst->replaceUsesOfWith(&I, Entry.second.Converted);
+      MadeChange = true;
+    }
+  }
+  return MadeChange;
+}
+
+/// Rebuilds each PHI recorded in PHIUpdater whose incoming values have all
+/// been converted by replaceUses(). A new PHI of the packed type is created
+/// from the recorded (converted value, incoming block) pairs, converted back
+/// to the original type at the block's first non-PHI position, and used to
+/// replace the old PHI, which is then erased. PHIs that still have an
+/// unconverted incoming value are left unchanged.
+///
+/// \returns true if at least one PHI was replaced.
+bool LiveRegOptimizer::replacePHIs() {
+  bool MadeChange = false;
+  // Bind by reference to avoid copying each entry's operand list.
+  for (const auto &[OldInst, NewOps] : PHIUpdater) {
+    auto *OldPHI = cast<PHINode>(OldInst);
+    LLVM_DEBUG(dbgs() << "Attempting to replace: " << *OldPHI << "\n");
+    // Only replace the PHI once every incoming value has a converted operand.
+    if (OldPHI->getNumIncomingValues() != NewOps.size())
+      continue;
+
+    IRBuilder<> Builder(OldPHI);
+    PHINode *NewPHI =
+        Builder.CreatePHI(NewOps.front().first->getType(), NewOps.size());
+    for (const auto &[IncValue, IncBlock] : NewOps) {
+      NewPHI->addIncoming(IncValue, IncBlock);
+      LLVM_DEBUG(dbgs() << "  Using: " << *IncValue
+                        << "  For: " << IncBlock->getName() << "\n");
+    }
+    LLVM_DEBUG(dbgs() << "Sucessfully replaced with " << *NewPHI << "\n");
+
+    // Convert the packed PHI back to the original type for the old users.
+    BasicBlock *PHIBlock = OldPHI->getParent();
+    LiveRegConversion ToLRC(NewPHI, OldPHI->getType(), PHIBlock,
+                            static_cast<BasicBlock::iterator>(
+                                PHIBlock->getFirstNonPHIIt()));
+    convertFromOptType(ToLRC);
+    assert(ToLRC.hasConverted());
+    OldPHI->replaceAllUsesWith(*ToLRC.getConverted());
+    // The old PHI is no longer used.
+    OldPHI->eraseFromParent();
+    MadeChange = true;
+  }
   return MadeChange;
 }
 
+/// Returns the type used to carry \p InstToConvert's value across blocks.
+/// A non-vector value maps directly to ConvertToScalar. A vector value maps
+/// to a vector of integers as wide as ConvertToScalar, with just enough
+/// elements to hold all of the original bits (the total is rounded up).
+Type *LiveRegOptimizer::getCompatibleType(Instruction *InstToConvert) {
+  Type *OrigTy = InstToConvert->getType();
+  const unsigned PackedEltBits = ConvertToScalar->getScalarSizeInBits();
+  assert(OrigTy->getScalarSizeInBits() <= PackedEltBits);
+
+  auto *OrigVecTy = dyn_cast<VectorType>(OrigTy);
+  if (OrigVecTy == nullptr)
+    return ConvertToScalar;
+
+  const unsigned OrigBits = OrigVecTy->getScalarSizeInBits() *
+                            OrigVecTy->getElementCount().getFixedValue();
+  // Ceiling division: a partially filled trailing element still needs a slot.
+  const unsigned PackedEltCount =
+      (OrigBits + PackedEltBits - 1) / PackedEltBits;
+  Type *PackedEltTy = Type::getIntNTy(Mod->getContext(), PackedEltBits);
+  return VectorType::get(PackedEltTy,
+                         llvm::ElementCount::getFixed(PackedEltCount));
+}
+
+/// Emits, at the insertion point held by \p LR, the conversion of its live
+/// def from the original vector type to the packed vector type, and records
+/// the result with LR.setConverted(). Does nothing if \p LR already holds a
+/// converted value.
+///
+/// When both types have the same total bit width a single bitcast suffices.
+/// Otherwise the def is first widened with a shufflevector to the packed
+/// type's total width (original lanes kept in order, every extra lane taken
+/// from the second shuffle operand, which the single-operand IRBuilder
+/// overload fills with poison), and the widened vector is then bitcast.
+void LiveRegOptimizer::convertToOptType(LiveRegConversion &LR) {
+  if (LR.hasConverted()) {
+    LLVM_DEBUG(dbgs() << "\tAlready has converted def\n");
+    return;
+  }
+
+  auto *SrcVecTy = dyn_cast<VectorType>(LR.getOriginalType());
+  assert(SrcVecTy);
+  auto *DstVecTy = dyn_cast<VectorType>(LR.getNewType());
+  assert(DstVecTy);
+
+  Value *Source = LR.getLiveRegDef();
+  const unsigned SrcEltBits = SrcVecTy->getScalarSizeInBits();
+  const unsigned SrcEltCount = SrcVecTy->getElementCount().getFixedValue();
+  const unsigned SrcBits = SrcEltBits * SrcEltCount;
+  const unsigned DstBits = DstVecTy->getScalarSizeInBits() *
+                           DstVecTy->getElementCount().getFixedValue();
+
+  auto &Builder = LR.getConverBuilder();
+
+  if (SrcBits != DstBits) {
+    // The packed type is wider: pad the source out to DstBits first. Lanes
+    // [0, SrcEltCount) keep their position; every extra lane selects index
+    // SrcEltCount, i.e. the first lane of the second shuffle operand.
+    assert(DstBits > SrcBits);
+    const unsigned WideEltCount = DstBits / SrcEltBits;
+    SmallVector<int, 8> Mask(WideEltCount, static_cast<int>(SrcEltCount));
+    for (unsigned Lane = 0; Lane < SrcEltCount; ++Lane)
+      Mask[Lane] = Lane;
+    Source = dyn_cast<Instruction>(Builder.CreateShuffleVector(Source, Mask));
+  }
+
+  LR.setConverted(
+      dyn_cast<Instruction>(Builder.CreateBitCast(Source, DstVecTy)));
+  LLVM_DEBUG(dbgs() << "\tConverted def to "
+                    << *(*LR.getConverted())->getType() << "\n");
+}
+
+void LiveRegOptimizer::convertFromOptType(LiveRegConversion &LRC) {
+  auto VTy = dyn_cast<VectorType>(LRC.getOriginalType());
+  assert(VTy);
+  auto NewVTy = dyn_cast<VectorType>(LRC.getNewType());
+  assert(NewVTy);
+
+  auto V = static_cast<Value *>(LRC.getLiveRegDef());
+  auto OriginalSize =
+      VTy->getScalarSizeInBits() * VTy->getElementCount().getFixedValue();
+  auto NewSize =
+      NewVTy->getScalarSizeInBits() * NewVTy->getElementCount().getFixedValue();
+
+  auto &Builder = LRC.getConverBuilder();
+
+  // If there is a bitsize match, we simply convert back to the original type
+  if (OriginalSize == NewSize) {
+    LRC.setConverted(dyn_cast<Instruction>(Builder.CreateBitCast(V, NewVTy)));
+    LLVM_DEBUG(dbgs() << "\tProduced for user: " << **LRC.getConverted()
+                      << "\n");
+    return;
+  }
+
+  // If there is a bitsize mismatch, we have used a wider vector and must strip
+  // the MSBs to convert back to the original type
+  assert(OriginalSize > NewSize);
+  auto ExpandedVecElementCount = llvm::ElementCount::getFixed(
+      OriginalSize / NewVTy->getScalarSizeInBits());
+  auto ExpandedVT = VectorType::get(
+      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      ExpandedVecElementCount);
+  auto Converted = dyn_cast<Instruction>(
+      Builder.CreateBitCast(LRC.getLiveRegDef(), ExpandedVT));
+
+  auto NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+  SmallVector<int, 8> ShuffleMask;
+  for (uint64_t I = 0; I < NarrowElementCount; I++)
----------------
arsenm wrote:

can do this with the constructor 

https://github.com/llvm/llvm-project/pull/66838


More information about the llvm-commits mailing list