[llvm] [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) (PR #127898)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 19 13:51:12 PST 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/127898

From 8730bc39da1ea76306fd1c5f5c202bb09a1cf98f Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 19 Feb 2025 21:21:06 +0000
Subject: [PATCH] [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC)

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     | 121 +++---------------
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h       |   6 +-
 .../Target/NVPTX/NVPTXMachineFunctionInfo.h   |   6 -
 .../Target/NVPTX/NVPTXReplaceImageHandles.cpp |  51 +++-----
 4 files changed, 38 insertions(+), 146 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index c8e29c1da6ec4..6e5dd6b15900c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -149,67 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
   EmitToStreamer(*OutStreamer, Inst);
 }
 
-// Handle symbol backtracking for targets that do not support image handles
-bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
-                                           unsigned OpNo, MCOperand &MCOp) {
-  const MachineOperand &MO = MI->getOperand(OpNo);
-  const MCInstrDesc &MCID = MI->getDesc();
-
-  if (MCID.TSFlags & NVPTXII::IsTexFlag) {
-    // This is a texture fetch, so operand 4 is a texref and operand 5 is
-    // a samplerref
-    if (OpNo == 4 && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-    if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
-    unsigned VecSize =
-      1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
-
-    // For a surface load of vector size N, the Nth operand will be the surfref
-    if (OpNo == VecSize && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
-    // This is a surface store, so operand 0 is a surfref
-    if (OpNo == 0 && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
-    // This is a query, so operand 1 is a surfref/texref
-    if (OpNo == 1 && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  }
-
-  return false;
-}
-
-void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
-  // Ewwww
-  TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
-  NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
-  const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
-  StringRef Sym = MFI->getImageHandleSymbol(Index);
-  StringRef SymName = nvTM.getStrPool().save(Sym);
-  MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
-}
-
 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
   OutMI.setOpcode(MI->getOpcode());
   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
@@ -220,67 +159,51 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
     return;
   }
 
-  for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
-    const MachineOperand &MO = MI->getOperand(i);
-
-    MCOperand MCOp;
-    if (lowerImageHandleOperand(MI, i, MCOp)) {
-      OutMI.addOperand(MCOp);
-      continue;
-    }
-
-    if (lowerOperand(MO, MCOp))
-      OutMI.addOperand(MCOp);
-  }
+  for (const MachineOperand &MO : MI->operands())
+    OutMI.addOperand(lowerOperand(MO));
 }
 
-bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
-                                   MCOperand &MCOp) {
+/// Lower a single MachineOperand to the equivalent MCOperand.
+MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
   switch (MO.getType()) {
-  default: llvm_unreachable("unknown operand type");
+  default:
+    llvm_unreachable("unknown operand type");
   case MachineOperand::MO_Register:
-    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
-    break;
+    return MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
   case MachineOperand::MO_Immediate:
-    MCOp = MCOperand::createImm(MO.getImm());
-    break;
+    return MCOperand::createImm(MO.getImm());
   case MachineOperand::MO_MachineBasicBlock:
-    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
-        MO.getMBB()->getSymbol(), OutContext));
-    break;
+    return MCOperand::createExpr(
+        MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext));
   case MachineOperand::MO_ExternalSymbol:
-    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
-    break;
+    return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
   case MachineOperand::MO_GlobalAddress:
-    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
-    break;
+    return GetSymbolRef(getSymbol(MO.getGlobal()));
   case MachineOperand::MO_FPImmediate: {
     const ConstantFP *Cnt = MO.getFPImm();
     const APFloat &Val = Cnt->getValueAPF();
 
     switch (Cnt->getType()->getTypeID()) {
-    default: report_fatal_error("Unsupported FP type"); break;
-    case Type::HalfTyID:
-      MCOp = MCOperand::createExpr(
-        NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
+    default:
+      report_fatal_error("Unsupported FP type");
       break;
+    case Type::HalfTyID:
+      return MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
     case Type::BFloatTyID:
-      MCOp = MCOperand::createExpr(
+      return MCOperand::createExpr(
           NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
-      break;
     case Type::FloatTyID:
-      MCOp = MCOperand::createExpr(
-        NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
-      break;
+      return MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
     case Type::DoubleTyID:
-      MCOp = MCOperand::createExpr(
-        NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
-      break;
+      return MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
     }
     break;
   }
   }
-  return true;
+  llvm_unreachable("all switch cases return or abort");
 }
 
 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index f7c3fda332eff..74daaa2fb7134 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
 
   void emitInstruction(const MachineInstr *) override;
   void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
-  bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
+  MCOperand lowerOperand(const MachineOperand &MO);
   MCOperand GetSymbolRef(const MCSymbol *Symbol);
   unsigned encodeVirtualRegister(unsigned Reg);
 
@@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O);
   void emitDemotedVars(const Function *, raw_ostream &);
 
-  bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
-                               MCOperand &MCOp);
-  void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
-
   bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
 
   // Used to control the need to emit .generic() in the initializer of
diff --git a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
index 6670cb296f216..d9beab7ec42e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
@@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo {
     return ImageHandleList.size()-1;
   }
 
-  /// Returns the symbol name at the given index.
-  StringRef getImageHandleSymbol(unsigned Idx) const {
-    assert(ImageHandleList.size() > Idx && "Bad index");
-    return ImageHandleList[Idx];
-  }
-
   /// Check if the symbol has a mapping. Having a mapping means the handle is
   /// replaced with a reference
   bool checkImageHandleSymbol(StringRef Symbol) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index a3e3978cbbfe2..4d0694faa0c9a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -20,7 +20,6 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
-#include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 
@@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass {
 private:
   bool processInstr(MachineInstr &MI);
   bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF);
-  bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF,
-                          unsigned &Idx);
 };
-}
+} // namespace
 
 char NVPTXReplaceImageHandles::ID = 0;
 
@@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
     }
 
     return true;
-  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
+  }
+  if (MCID.TSFlags & NVPTXII::IsSuldMask) {
     unsigned VecSize =
-      1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
+        1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
+              1);
 
     // For a surface load of vector size N, the Nth operand will be the surfref
     MachineOperand &SurfHandle = MI.getOperand(VecSize);
@@ -1767,7 +1766,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
       MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode())));
 
     return true;
-  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
+  }
+  if (MCID.TSFlags & NVPTXII::IsSustFlag) {
     // This is a surface store, so operand 0 is a surfref
     MachineOperand &SurfHandle = MI.getOperand(0);
 
@@ -1775,7 +1775,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
       MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode())));
 
     return true;
-  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
+  }
+  if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
     // This is a query, so operand 1 is a surfref/texref
     MachineOperand &Handle = MI.getOperand(1);
 
@@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
 
 bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
                                                   MachineFunction &MF) {
-  unsigned Idx;
-  if (findIndexForHandle(Op, MF, Idx)) {
-    Op.ChangeToImmediate(Idx);
-    return true;
-  }
-  return false;
-}
-
-bool NVPTXReplaceImageHandles::
-findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
   const MachineRegisterInfo &MRI = MF.getRegInfo();
   NVPTXMachineFunctionInfo *MFI = MF.getInfo<NVPTXMachineFunctionInfo>();
 
@@ -1812,25 +1803,19 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
   case NVPTX::LD_i64_avar: {
     // The handle is a parameter value being loaded, replace with the
     // parameter symbol
-    const NVPTXTargetMachine &TM =
-        static_cast<const NVPTXTargetMachine &>(MF.getTarget());
+    const auto &TM = static_cast<const NVPTXTargetMachine &>(MF.getTarget());
     if (TM.getDrvInterface() == NVPTX::CUDA) {
       // For CUDA, we preserve the param loads coming from function arguments
       return false;
     }
 
     assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
     StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
-    std::string ParamBaseName = std::string(MF.getName());
-    ParamBaseName += "_param_";
-    assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference");
-    unsigned Param = atoi(Sym.data()+ParamBaseName.size());
-    std::string NewSym;
-    raw_string_ostream NewSymStr(NewSym);
-    NewSymStr << MF.getName() << "_param_" << Param;
-
     InstrsToRemove.insert(&TexHandleDef);
-    Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str());
+    Op.ChangeToES(Sym.data());
+    // Only the side effect is needed: it registers Sym in the function's
+    // image-handle list (see checkImageHandleSymbol); the index is unused.
+    MFI->getImageHandleSymbolIndex(Sym);
     return true;
   }
   case NVPTX::texsurf_handles: {
@@ -1839,15 +1824,23 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
     const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal();
     assert(GV->hasName() && "Global sampler must be named!");
     InstrsToRemove.insert(&TexHandleDef);
-    Idx = MFI->getImageHandleSymbolIndex(GV->getName());
+    Op.ChangeToGA(GV, 0);
     return true;
   }
   case NVPTX::nvvm_move_i64:
   case TargetOpcode::COPY: {
-    bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx);
+    MachineOperand &Src = TexHandleDef.getOperand(1);
+    bool Res = replaceImageHandle(Src, MF);
     if (Res) {
+      // The recursive call rewrote the COPY's source operand, not Op, so
+      // forward the resolved symbol to Op as well; otherwise Op would still
+      // read a vreg whose defining instruction is queued for removal below.
+      if (Src.isGlobal())
+        Op.ChangeToGA(Src.getGlobal(), Src.getOffset());
+      else
+        Op.ChangeToES(Src.getSymbolName());
       InstrsToRemove.insert(&TexHandleDef);
     }
     return Res;
   }
   default:



More information about the llvm-commits mailing list