[llvm-branch-commits] [llvm] AMDGPU/GlobalISel: AMDGPURegBankLegalize (PR #112864)
Petar Avramovic via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Nov 5 08:08:38 PST 2024
================
@@ -69,11 +82,297 @@ FunctionPass *llvm::createAMDGPURegBankLegalizePass() {
return new AMDGPURegBankLegalize();
}
-using namespace AMDGPU;
+// Returns the RegBankLegalizeRules for ST's generation. One rule set is built
+// per generation (the key is ST.getGeneration()) and kept in a function-local
+// static cache for the rest of the process; on a cache hit the cached object
+// is handed the current ST/MRI via refreshRefs() before being returned.
+//
+// NOTE(review): GlobalMutex only covers the lookup/refresh. The returned
+// reference is used after the lock is released, so if refreshRefs() rebinds
+// per-function state, two threads compiling functions of the same generation
+// concurrently could observe each other's ST/MRI. Confirm this is acceptable.
+const RegBankLegalizeRules &getRules(const GCNSubtarget &ST,
+                                     MachineRegisterInfo &MRI) {
+  static std::mutex GlobalMutex;
+  static SmallDenseMap<unsigned, std::unique_ptr<RegBankLegalizeRules>>
+      CacheForRuleSet;
+  std::lock_guard<std::mutex> Lock(GlobalMutex);
+  // A single hash lookup: try_emplace inserts an empty slot on a miss and
+  // reports whether it did, so we never probe the map more than once.
+  auto [It, Inserted] = CacheForRuleSet.try_emplace(ST.getGeneration());
+  if (Inserted)
+    It->second = std::make_unique<RegBankLegalizeRules>(ST, MRI);
+  else
+    It->second->refreshRefs(ST, MRI);
+  return *It->second;
+}
+
+/// Peephole combines used by the AMDGPURegBankLegalize pass. Each tryCombine*
+/// method inspects one instruction. If its pattern matches, the method emits a
+/// replacement, erases the instruction, and also erases the matched defining
+/// instruction once that has become trivially dead. Otherwise the instruction
+/// is left untouched.
+class AMDGPURegBankLegalizeCombiner {
+  MachineIRBuilder &B;
+  MachineRegisterInfo &MRI;
+  const SIRegisterInfo &TRI;
+  const RegisterBank *SgprRB;
+  const RegisterBank *VgprRB;
+  const RegisterBank *VccRB;
+
+  static constexpr LLT S1 = LLT::scalar(1);
+  static constexpr LLT S16 = LLT::scalar(16);
+  static constexpr LLT S32 = LLT::scalar(32);
+  static constexpr LLT S64 = LLT::scalar(64);
+
+public:
+  AMDGPURegBankLegalizeCombiner(MachineIRBuilder &B, const SIRegisterInfo &TRI,
+                                const RegisterBankInfo &RBI)
+      : B(B), MRI(*B.getMRI()), TRI(TRI),
+        SgprRB(&RBI.getRegBank(AMDGPU::SGPRRegBankID)),
+        VgprRB(&RBI.getRegBank(AMDGPU::VGPRRegBankID)),
+        VccRB(&RBI.getRegBank(AMDGPU::VCCRegBankID)) {}
+
+  /// Returns true if \p Reg is a lane mask. That is the case when it is
+  /// assigned to the VCC register bank, or when it is an S1 register
+  /// constrained to an SGPR register class.
+  bool isLaneMask(Register Reg) {
+    const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
+    if (RB && RB->getID() == AMDGPU::VCCRegBankID)
+      return true;
+
+    const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg);
+    return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg) == S1;
+  }
+
+  /// Erases \p MI. Then erases \p Optional0 as well, if it is non-null and has
+  /// become trivially dead now that MI is gone.
+  void cleanUpAfterCombine(MachineInstr &MI, MachineInstr *Optional0) {
+    MI.eraseFromParent();
+    if (Optional0 && isTriviallyDead(*Optional0, MRI))
+      Optional0->eraseFromParent();
+  }
+
+  /// If \p Src is defined by an instruction with opcode \p Opcode, returns
+  /// that instruction together with its operand-1 register. Otherwise returns
+  /// {nullptr, Register()}.
+  std::pair<MachineInstr *, Register> tryMatch(Register Src, unsigned Opcode) {
+    MachineInstr *MatchMI = MRI.getVRegDef(Src);
+    // getVRegDef returns null when no definition is found; treat that as a
+    // non-match instead of dereferencing it.
+    if (!MatchMI || MatchMI->getOpcode() != Opcode)
+      return {nullptr, Register()};
+    return {MatchMI, MatchMI->getOperand(1).getReg()};
+  }
+
+  /// Combines a COPY between virtual registers in two cases:
+  ///  - sgpr S1 produced by G_TRUNC of sgpr S32, copied into a lane mask;
+  ///  - sgpr value produced by G_READANYLANE, copied back into a vgpr.
+  void tryCombineCopy(MachineInstr &MI) {
+    Register Dst = MI.getOperand(0).getReg();
+    Register Src = MI.getOperand(1).getReg();
+    // Skip copies of physical registers.
+    if (!Dst.isVirtual() || !Src.isVirtual())
+      return;
+
+    // This is a cross bank copy, sgpr S1 to lane mask.
+    //
+    // %Src:sgpr(s1) = G_TRUNC %TruncS32Src:sgpr(s32)
+    // %Dst:lane-mask(s1) = COPY %Src:sgpr(s1)
+    // ->
+    // %One:sgpr(s32) = G_CONSTANT i32 1
+    // %BoolSrc:sgpr(s32) = G_AND %TruncS32Src, %One
+    // %Dst:lane-mask(s1) = G_COPY_VCC_SCC %BoolSrc:sgpr(s32)
+    if (isLaneMask(Dst) && MRI.getRegBankOrNull(Src) == SgprRB) {
+      auto [Trunc, TruncS32Src] = tryMatch(Src, AMDGPU::G_TRUNC);
+      assert(Trunc && MRI.getType(TruncS32Src) == S32 &&
+             "sgpr S1 must be result of G_TRUNC of sgpr S32");
+
+      B.setInstr(MI);
+      // Clear the bits that G_TRUNC discarded so that only bit 0 of BoolSrc
+      // can be set.
+      auto One = B.buildConstant({SgprRB, S32}, 1);
+      auto BoolSrc = B.buildAnd({SgprRB, S32}, TruncS32Src, One);
+      B.buildInstr(AMDGPU::G_COPY_VCC_SCC, {Dst}, {BoolSrc});
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    // Src = G_READANYLANE RALSrc
+    // Dst = COPY Src
+    // ->
+    // Dst = RALSrc
+    if (MRI.getRegBankOrNull(Dst) == VgprRB &&
+        MRI.getRegBankOrNull(Src) == SgprRB) {
+      auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_READANYLANE);
+      if (!RAL)
+        return;
+
+      assert(MRI.getRegBank(RALSrc) == VgprRB);
+      MRI.replaceRegWith(Dst, RALSrc);
+      cleanUpAfterCombine(MI, RAL);
+      return;
+    }
+  }
+
+  /// Folds a G_ANYEXT of an S1 produced by G_TRUNC into a direct use of the
+  /// G_TRUNC source, resizing that source to the G_ANYEXT result type when
+  /// necessary.
+  void tryCombineS1AnyExt(MachineInstr &MI) {
+    // %Src:sgpr(S1) = G_TRUNC %TruncSrc
+    // %Dst = G_ANYEXT %Src:sgpr(S1)
+    // ->
+    // %Dst = G_... %TruncSrc
+    Register Dst = MI.getOperand(0).getReg();
+    Register Src = MI.getOperand(1).getReg();
+    if (MRI.getType(Src) != S1)
+      return;
+
+    auto [Trunc, TruncSrc] = tryMatch(Src, AMDGPU::G_TRUNC);
+    if (!Trunc)
+      return;
+
+    LLT DstTy = MRI.getType(Dst);
+    LLT TruncSrcTy = MRI.getType(TruncSrc);
+
+    // Same width: the any-extended value can simply be TruncSrc itself.
+    if (DstTy == TruncSrcTy) {
+      MRI.replaceRegWith(Dst, TruncSrc);
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    B.setInstr(MI);
+
+    // S64 -> S32: take the low half of TruncSrc.
+    // NOTE(review): the halves are created on the SGPR bank without checking
+    // TruncSrc's bank. This assumes the G_TRUNC source is SGPR, as the pattern
+    // comment above states.
+    if (DstTy == S32 && TruncSrcTy == S64) {
+      auto Unmerge = B.buildUnmerge({SgprRB, S32}, TruncSrc);
+      MRI.replaceRegWith(Dst, Unmerge.getReg(0));
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    if (DstTy == S32 && TruncSrcTy == S16) {
+      B.buildAnyExt(Dst, TruncSrc);
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    if (DstTy == S16 && TruncSrcTy == S32) {
+      B.buildTrunc(Dst, TruncSrc);
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    llvm_unreachable("missing anyext + trunc combine");
+  }
+};
+
+static bool hasSgprS1(MachineFunction &MF, MachineRegisterInfo &MRI,
----------------
petar-avramovic wrote:
This is target-specific, and it should only run between RegBankLegalize and instruction selection; I'm not sure how to express that in the machine verifier. I added [[maybe_unused]] to silence compiler warnings.
https://github.com/llvm/llvm-project/pull/112864
More information about the llvm-branch-commits
mailing list