[llvm] [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) (PR #127898)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 19 13:22:09 PST 2025
https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/127898
Prior to this change, NVPTXReplaceImageHandles replaced image-handle operands with integer indices and populated a table mapping those indices to symbol names, which NVPTXAsmPrinter then looked up to emit the operands. We can clean this up by inserting the correct external-symbol or global-address operands directly in NVPTXReplaceImageHandles, largely removing the need for the table.
From 965c57cf5b73b18afb166448d031e2d3e0dfb240 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 19 Feb 2025 21:21:06 +0000
Subject: [PATCH] [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC)
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 120 +++---------------
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h | 6 +-
.../Target/NVPTX/NVPTXMachineFunctionInfo.h | 6 -
.../Target/NVPTX/NVPTXReplaceImageHandles.cpp | 51 +++-----
4 files changed, 38 insertions(+), 145 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index c8e29c1da6ec4..b36e677d5e67a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -149,66 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
EmitToStreamer(*OutStreamer, Inst);
}
-// Handle symbol backtracking for targets that do not support image handles
-bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
- unsigned OpNo, MCOperand &MCOp) {
- const MachineOperand &MO = MI->getOperand(OpNo);
- const MCInstrDesc &MCID = MI->getDesc();
-
- if (MCID.TSFlags & NVPTXII::IsTexFlag) {
- // This is a texture fetch, so operand 4 is a texref and operand 5 is
- // a samplerref
- if (OpNo == 4 && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
- if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
- unsigned VecSize =
- 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
-
- // For a surface load of vector size N, the Nth operand will be the surfref
- if (OpNo == VecSize && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
- // This is a surface store, so operand 0 is a surfref
- if (OpNo == 0 && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
- // This is a query, so operand 1 is a surfref/texref
- if (OpNo == 1 && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- }
-
- return false;
-}
-
-void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
- // Ewwww
- TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
- NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
- const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
- StringRef Sym = MFI->getImageHandleSymbol(Index);
- StringRef SymName = nvTM.getStrPool().save(Sym);
- MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
-}
void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
OutMI.setOpcode(MI->getOpcode());
@@ -220,67 +160,49 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
return;
}
- for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
- const MachineOperand &MO = MI->getOperand(i);
-
- MCOperand MCOp;
- if (lowerImageHandleOperand(MI, i, MCOp)) {
- OutMI.addOperand(MCOp);
- continue;
- }
-
- if (lowerOperand(MO, MCOp))
- OutMI.addOperand(MCOp);
- }
+ for (const auto MO : MI->operands())
+ OutMI.addOperand(lowerOperand(MO));
}
-bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
- MCOperand &MCOp) {
+MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
switch (MO.getType()) {
- default: llvm_unreachable("unknown operand type");
+ default:
+ llvm_unreachable("unknown operand type");
case MachineOperand::MO_Register:
- MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
- break;
+ return MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
case MachineOperand::MO_Immediate:
- MCOp = MCOperand::createImm(MO.getImm());
- break;
+ return MCOperand::createImm(MO.getImm());
case MachineOperand::MO_MachineBasicBlock:
- MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
- MO.getMBB()->getSymbol(), OutContext));
- break;
+ return MCOperand::createExpr(
+ MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext));
case MachineOperand::MO_ExternalSymbol:
- MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
- break;
+ return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
case MachineOperand::MO_GlobalAddress:
- MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
- break;
+ return GetSymbolRef(getSymbol(MO.getGlobal()));
case MachineOperand::MO_FPImmediate: {
const ConstantFP *Cnt = MO.getFPImm();
const APFloat &Val = Cnt->getValueAPF();
switch (Cnt->getType()->getTypeID()) {
- default: report_fatal_error("Unsupported FP type"); break;
- case Type::HalfTyID:
- MCOp = MCOperand::createExpr(
- NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
+ default:
+ report_fatal_error("Unsupported FP type");
break;
+ case Type::HalfTyID:
+ return MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
case Type::BFloatTyID:
- MCOp = MCOperand::createExpr(
+ return MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
- break;
case Type::FloatTyID:
- MCOp = MCOperand::createExpr(
- NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
- break;
+ return MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
case Type::DoubleTyID:
- MCOp = MCOperand::createExpr(
- NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
- break;
+ return MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
}
break;
}
}
- return true;
}
unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index f7c3fda332eff..74daaa2fb7134 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
void emitInstruction(const MachineInstr *) override;
void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
- bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
+ MCOperand lowerOperand(const MachineOperand &MO);
MCOperand GetSymbolRef(const MCSymbol *Symbol);
unsigned encodeVirtualRegister(unsigned Reg);
@@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O);
void emitDemotedVars(const Function *, raw_ostream &);
- bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
- MCOperand &MCOp);
- void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
-
bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
// Used to control the need to emit .generic() in the initializer of
diff --git a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
index 6670cb296f216..d9beab7ec42e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
@@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo {
return ImageHandleList.size()-1;
}
- /// Returns the symbol name at the given index.
- StringRef getImageHandleSymbol(unsigned Idx) const {
- assert(ImageHandleList.size() > Idx && "Bad index");
- return ImageHandleList[Idx];
- }
-
/// Check if the symbol has a mapping. Having a mapping means the handle is
/// replaced with a reference
bool checkImageHandleSymbol(StringRef Symbol) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index a3e3978cbbfe2..4d0694faa0c9a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -20,7 +20,6 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
-#include "llvm/Support/raw_ostream.h"
using namespace llvm;
@@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass {
private:
bool processInstr(MachineInstr &MI);
bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF);
- bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF,
- unsigned &Idx);
};
-}
+} // namespace
char NVPTXReplaceImageHandles::ID = 0;
@@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
}
return true;
- } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
+ }
+ if (MCID.TSFlags & NVPTXII::IsSuldMask) {
unsigned VecSize =
- 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
+ 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
+ 1);
// For a surface load of vector size N, the Nth operand will be the surfref
MachineOperand &SurfHandle = MI.getOperand(VecSize);
@@ -1767,7 +1766,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode())));
return true;
- } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
+ }
+ if (MCID.TSFlags & NVPTXII::IsSustFlag) {
// This is a surface store, so operand 0 is a surfref
MachineOperand &SurfHandle = MI.getOperand(0);
@@ -1775,7 +1775,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode())));
return true;
- } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
+ }
+ if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
// This is a query, so operand 1 is a surfref/texref
MachineOperand &Handle = MI.getOperand(1);
@@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
MachineFunction &MF) {
- unsigned Idx;
- if (findIndexForHandle(Op, MF, Idx)) {
- Op.ChangeToImmediate(Idx);
- return true;
- }
- return false;
-}
-
-bool NVPTXReplaceImageHandles::
-findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
const MachineRegisterInfo &MRI = MF.getRegInfo();
NVPTXMachineFunctionInfo *MFI = MF.getInfo<NVPTXMachineFunctionInfo>();
@@ -1812,25 +1803,16 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
case NVPTX::LD_i64_avar: {
// The handle is a parameter value being loaded, replace with the
// parameter symbol
- const NVPTXTargetMachine &TM =
- static_cast<const NVPTXTargetMachine &>(MF.getTarget());
- if (TM.getDrvInterface() == NVPTX::CUDA) {
+ const auto &TM = static_cast<const NVPTXTargetMachine &>(MF.getTarget());
+ if (TM.getDrvInterface() == NVPTX::CUDA)
// For CUDA, we preserve the param loads coming from function arguments
return false;
- }
assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
- std::string ParamBaseName = std::string(MF.getName());
- ParamBaseName += "_param_";
- assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference");
- unsigned Param = atoi(Sym.data()+ParamBaseName.size());
- std::string NewSym;
- raw_string_ostream NewSymStr(NewSym);
- NewSymStr << MF.getName() << "_param_" << Param;
-
InstrsToRemove.insert(&TexHandleDef);
- Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str());
+ Op.ChangeToES(Sym.data());
+ MFI->getImageHandleSymbolIndex(Sym);
return true;
}
case NVPTX::texsurf_handles: {
@@ -1839,15 +1821,14 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal();
assert(GV->hasName() && "Global sampler must be named!");
InstrsToRemove.insert(&TexHandleDef);
- Idx = MFI->getImageHandleSymbolIndex(GV->getName());
+ Op.ChangeToGA(GV, 0);
return true;
}
case NVPTX::nvvm_move_i64:
case TargetOpcode::COPY: {
- bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx);
- if (Res) {
+ bool Res = replaceImageHandle(TexHandleDef.getOperand(1), MF);
+ if (Res)
InstrsToRemove.insert(&TexHandleDef);
- }
return Res;
}
default:
More information about the llvm-commits
mailing list