[llvm] 62e4436 - [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) (#127898)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 19 20:59:35 PST 2025


Author: Alex MacLean
Date: 2025-02-19T20:59:31-08:00
New Revision: 62e4436bb80b82c836fb747b391696700f089e6b

URL: https://github.com/llvm/llvm-project/commit/62e4436bb80b82c836fb747b391696700f089e6b
DIFF: https://github.com/llvm/llvm-project/commit/62e4436bb80b82c836fb747b391696700f089e6b.diff

LOG: [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) (#127898)

Prior to this change, NVPTXReplaceImageHandles replaced operands with
indices and populated a table matching these indices to strings to be
used in AsmPrinter. We can clean this up by simply inserting the correct
external symbol or global address operands during
NVPTXReplaceImageHandles, largely removing the need for the table.

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
    llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
    llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index c8e29c1da6ec4..6e5dd6b15900c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -149,67 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
   EmitToStreamer(*OutStreamer, Inst);
 }
 
-// Handle symbol backtracking for targets that do not support image handles
-bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
-                                           unsigned OpNo, MCOperand &MCOp) {
-  const MachineOperand &MO = MI->getOperand(OpNo);
-  const MCInstrDesc &MCID = MI->getDesc();
-
-  if (MCID.TSFlags & NVPTXII::IsTexFlag) {
-    // This is a texture fetch, so operand 4 is a texref and operand 5 is
-    // a samplerref
-    if (OpNo == 4 && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-    if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
-    unsigned VecSize =
-      1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
-
-    // For a surface load of vector size N, the Nth operand will be the surfref
-    if (OpNo == VecSize && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
-    // This is a surface store, so operand 0 is a surfref
-    if (OpNo == 0 && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
-    // This is a query, so operand 1 is a surfref/texref
-    if (OpNo == 1 && MO.isImm()) {
-      lowerImageHandleSymbol(MO.getImm(), MCOp);
-      return true;
-    }
-
-    return false;
-  }
-
-  return false;
-}
-
-void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
-  // Ewwww
-  TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
-  NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
-  const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
-  StringRef Sym = MFI->getImageHandleSymbol(Index);
-  StringRef SymName = nvTM.getStrPool().save(Sym);
-  MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
-}
-
 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
   OutMI.setOpcode(MI->getOpcode());
   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
@@ -220,67 +159,49 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
     return;
   }
 
-  for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
-    const MachineOperand &MO = MI->getOperand(i);
-
-    MCOperand MCOp;
-    if (lowerImageHandleOperand(MI, i, MCOp)) {
-      OutMI.addOperand(MCOp);
-      continue;
-    }
-
-    if (lowerOperand(MO, MCOp))
-      OutMI.addOperand(MCOp);
-  }
+  for (const auto MO : MI->operands())
+    OutMI.addOperand(lowerOperand(MO));
 }
 
-bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
-                                   MCOperand &MCOp) {
+MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
   switch (MO.getType()) {
-  default: llvm_unreachable("unknown operand type");
+  default:
+    llvm_unreachable("unknown operand type");
   case MachineOperand::MO_Register:
-    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
-    break;
+    return MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
   case MachineOperand::MO_Immediate:
-    MCOp = MCOperand::createImm(MO.getImm());
-    break;
+    return MCOperand::createImm(MO.getImm());
   case MachineOperand::MO_MachineBasicBlock:
-    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
-        MO.getMBB()->getSymbol(), OutContext));
-    break;
+    return MCOperand::createExpr(
+        MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext));
   case MachineOperand::MO_ExternalSymbol:
-    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
-    break;
+    return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
   case MachineOperand::MO_GlobalAddress:
-    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
-    break;
+    return GetSymbolRef(getSymbol(MO.getGlobal()));
   case MachineOperand::MO_FPImmediate: {
     const ConstantFP *Cnt = MO.getFPImm();
     const APFloat &Val = Cnt->getValueAPF();
 
     switch (Cnt->getType()->getTypeID()) {
-    default: report_fatal_error("Unsupported FP type"); break;
-    case Type::HalfTyID:
-      MCOp = MCOperand::createExpr(
-        NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
+    default:
+      report_fatal_error("Unsupported FP type");
       break;
+    case Type::HalfTyID:
+      return MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
     case Type::BFloatTyID:
-      MCOp = MCOperand::createExpr(
+      return MCOperand::createExpr(
           NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
-      break;
     case Type::FloatTyID:
-      MCOp = MCOperand::createExpr(
-        NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
-      break;
+      return MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
     case Type::DoubleTyID:
-      MCOp = MCOperand::createExpr(
-        NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
-      break;
+      return MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
     }
     break;
   }
   }
-  return true;
 }
 
 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {

diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index f7c3fda332eff..74daaa2fb7134 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
 
   void emitInstruction(const MachineInstr *) override;
   void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
-  bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
+  MCOperand lowerOperand(const MachineOperand &MO);
   MCOperand GetSymbolRef(const MCSymbol *Symbol);
   unsigned encodeVirtualRegister(unsigned Reg);
 
@@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O);
   void emitDemotedVars(const Function *, raw_ostream &);
 
-  bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
-                               MCOperand &MCOp);
-  void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
-
   bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
 
   // Used to control the need to emit .generic() in the initializer of

diff  --git a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
index 6670cb296f216..d9beab7ec42e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
@@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo {
     return ImageHandleList.size()-1;
   }
 
-  /// Returns the symbol name at the given index.
-  StringRef getImageHandleSymbol(unsigned Idx) const {
-    assert(ImageHandleList.size() > Idx && "Bad index");
-    return ImageHandleList[Idx];
-  }
-
   /// Check if the symbol has a mapping. Having a mapping means the handle is
   /// replaced with a reference
   bool checkImageHandleSymbol(StringRef Symbol) const {

diff  --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index a3e3978cbbfe2..4d0694faa0c9a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -20,7 +20,6 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
-#include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 
@@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass {
 private:
   bool processInstr(MachineInstr &MI);
   bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF);
-  bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF,
-                          unsigned &Idx);
 };
-}
+} // namespace
 
 char NVPTXReplaceImageHandles::ID = 0;
 
@@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
     }
 
     return true;
-  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
+  }
+  if (MCID.TSFlags & NVPTXII::IsSuldMask) {
     unsigned VecSize =
-      1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
+        1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
+              1);
 
     // For a surface load of vector size N, the Nth operand will be the surfref
     MachineOperand &SurfHandle = MI.getOperand(VecSize);
@@ -1767,7 +1766,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
       MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode())));
 
     return true;
-  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
+  }
+  if (MCID.TSFlags & NVPTXII::IsSustFlag) {
     // This is a surface store, so operand 0 is a surfref
     MachineOperand &SurfHandle = MI.getOperand(0);
 
@@ -1775,7 +1775,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
       MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode())));
 
     return true;
-  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
+  }
+  if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
     // This is a query, so operand 1 is a surfref/texref
     MachineOperand &Handle = MI.getOperand(1);
 
@@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
 
 bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
                                                   MachineFunction &MF) {
-  unsigned Idx;
-  if (findIndexForHandle(Op, MF, Idx)) {
-    Op.ChangeToImmediate(Idx);
-    return true;
-  }
-  return false;
-}
-
-bool NVPTXReplaceImageHandles::
-findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
   const MachineRegisterInfo &MRI = MF.getRegInfo();
   NVPTXMachineFunctionInfo *MFI = MF.getInfo<NVPTXMachineFunctionInfo>();
 
@@ -1812,25 +1803,16 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
   case NVPTX::LD_i64_avar: {
     // The handle is a parameter value being loaded, replace with the
     // parameter symbol
-    const NVPTXTargetMachine &TM =
-        static_cast<const NVPTXTargetMachine &>(MF.getTarget());
-    if (TM.getDrvInterface() == NVPTX::CUDA) {
+    const auto &TM = static_cast<const NVPTXTargetMachine &>(MF.getTarget());
+    if (TM.getDrvInterface() == NVPTX::CUDA)
       // For CUDA, we preserve the param loads coming from function arguments
       return false;
-    }
 
     assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
     StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
-    std::string ParamBaseName = std::string(MF.getName());
-    ParamBaseName += "_param_";
-    assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference");
-    unsigned Param = atoi(Sym.data()+ParamBaseName.size());
-    std::string NewSym;
-    raw_string_ostream NewSymStr(NewSym);
-    NewSymStr << MF.getName() << "_param_" << Param;
-
     InstrsToRemove.insert(&TexHandleDef);
-    Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str());
+    Op.ChangeToES(Sym.data());
+    MFI->getImageHandleSymbolIndex(Sym);
     return true;
   }
   case NVPTX::texsurf_handles: {
@@ -1839,15 +1821,14 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
     const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal();
     assert(GV->hasName() && "Global sampler must be named!");
     InstrsToRemove.insert(&TexHandleDef);
-    Idx = MFI->getImageHandleSymbolIndex(GV->getName());
+    Op.ChangeToGA(GV, 0);
     return true;
   }
   case NVPTX::nvvm_move_i64:
   case TargetOpcode::COPY: {
-    bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx);
-    if (Res) {
+    bool Res = replaceImageHandle(TexHandleDef.getOperand(1), MF);
+    if (Res)
       InstrsToRemove.insert(&TexHandleDef);
-    }
     return Res;
   }
   default:


        


More information about the llvm-commits mailing list