[llvm] 08e5a1d - [llvm][NVPTX] Fix quadratic runtime in ProxyRegErasure (#105730)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 17:23:50 PDT 2024


Author: Jeff Niu
Date: 2024-08-22T20:23:44-04:00
New Revision: 08e5a1de8227512d4774a534b91cb2353cef6284

URL: https://github.com/llvm/llvm-project/commit/08e5a1de8227512d4774a534b91cb2353cef6284
DIFF: https://github.com/llvm/llvm-project/commit/08e5a1de8227512d4774a534b91cb2353cef6284.diff

LOG: [llvm][NVPTX] Fix quadratic runtime in ProxyRegErasure (#105730)

This pass performs RAUW by walking the machine function for each RAUW
operation, so for large functions its runtime blows up quadratically.
Linearize the pass by batching all RAUW operations and applying them in a
single walk over the machine function.

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp b/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp
index 258ae97a20d582..f3a3362addb0ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp
@@ -34,7 +34,6 @@ void initializeNVPTXProxyRegErasurePass(PassRegistry &);
 namespace {
 
 struct NVPTXProxyRegErasure : public MachineFunctionPass {
-public:
   static char ID;
   NVPTXProxyRegErasure() : MachineFunctionPass(ID) {
     initializeNVPTXProxyRegErasurePass(*PassRegistry::getPassRegistry());
@@ -49,23 +48,22 @@ struct NVPTXProxyRegErasure : public MachineFunctionPass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     MachineFunctionPass::getAnalysisUsage(AU);
   }
-
-private:
-  void replaceMachineInstructionUsage(MachineFunction &MF, MachineInstr &MI);
-
-  void replaceRegisterUsage(MachineInstr &Instr, MachineOperand &From,
-                            MachineOperand &To);
 };
 
 } // namespace
 
 char NVPTXProxyRegErasure::ID = 0;
 
-INITIALIZE_PASS(NVPTXProxyRegErasure, "nvptx-proxyreg-erasure", "NVPTX ProxyReg Erasure", false, false)
+INITIALIZE_PASS(NVPTXProxyRegErasure, "nvptx-proxyreg-erasure",
+                "NVPTX ProxyReg Erasure", false, false)
 
 bool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) {
   SmallVector<MachineInstr *, 16> RemoveList;
 
+  // ProxyReg instructions forward a register as another: `%dst = mov.iN %src`.
+  // Bulk RAUW the `%dst` registers in two passes over the machine function.
+  DenseMap<Register, Register> RAUWBatch;
+
   for (auto &BB : MF) {
     for (auto &MI : BB) {
       switch (MI.getOpcode()) {
@@ -74,44 +72,58 @@ bool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) {
       case NVPTX::ProxyRegI32:
       case NVPTX::ProxyRegI64:
       case NVPTX::ProxyRegF32:
-      case NVPTX::ProxyRegF64:
-        replaceMachineInstructionUsage(MF, MI);
+      case NVPTX::ProxyRegF64: {
+        auto &InOp = *MI.uses().begin();
+        auto &OutOp = *MI.defs().begin();
+        assert(InOp.isReg() && "ProxyReg input should be a register.");
+        assert(OutOp.isReg() && "ProxyReg output should be a register.");
         RemoveList.push_back(&MI);
+        RAUWBatch.try_emplace(OutOp.getReg(), InOp.getReg());
         break;
       }
+      }
     }
   }
 
+  // If there were no proxy instructions, exit early.
+  if (RemoveList.empty())
+    return false;
+
+  // Erase the proxy instructions first.
   for (auto *MI : RemoveList) {
     MI->eraseFromParent();
   }
 
-  return !RemoveList.empty();
-}
-
-void NVPTXProxyRegErasure::replaceMachineInstructionUsage(MachineFunction &MF,
-                                                          MachineInstr &MI) {
-  auto &InOp = *MI.uses().begin();
-  auto &OutOp = *MI.defs().begin();
-
-  assert(InOp.isReg() && "ProxyReg input operand should be a register.");
-  assert(OutOp.isReg() && "ProxyReg output operand should be a register.");
-
+  // Now go replace the registers.
   for (auto &BB : MF) {
-    for (auto &I : BB) {
-      replaceRegisterUsage(I, OutOp, InOp);
+    for (auto &MI : BB) {
+      for (auto &Op : MI.uses()) {
+        if (!Op.isReg())
+          continue;
+        auto It = RAUWBatch.find(Op.getReg());
+        if (It == RAUWBatch.end())
+          continue;
+        // A ProxyReg may forward another ProxyReg's result, e.g.
+        // `%b = ProxyReg %a` followed by `%c = ProxyReg %b`. All ProxyReg defs
+        // were erased above, so rewriting `%c` to `%b` would leave a use of an
+        // undefined register; follow the chain to a register that is not a
+        // ProxyReg result instead. Resolving here, rather than while building
+        // RAUWBatch, works whatever order the ProxyRegs were visited in. The
+        // walk is capped at the map size so a malformed cyclic chain cannot
+        // hang the pass.
+        Register Repl = It->second;
+        for (unsigned Step = 0, E = RAUWBatch.size(); Step != E; ++Step) {
+          auto Next = RAUWBatch.find(Repl);
+          if (Next == RAUWBatch.end())
+            break;
+          Repl = Next->second;
+        }
+        Op.setReg(Repl);
+      }
     }
   }
-}
 
-void NVPTXProxyRegErasure::replaceRegisterUsage(MachineInstr &Instr,
-                                                MachineOperand &From,
-                                                MachineOperand &To) {
-  for (auto &Op : Instr.uses()) {
-    if (Op.isReg() && Op.getReg() == From.getReg()) {
-      Op.setReg(To.getReg());
-    }
-  }
+  return true;
 }
 
 MachineFunctionPass *llvm::createNVPTXProxyRegErasurePass() {


        


More information about the llvm-commits mailing list