[llvm] [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) (PR #127898)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 19 13:51:12 PST 2025
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/127898
>From 8730bc39da1ea76306fd1c5f5c202bb09a1cf98f Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 19 Feb 2025 21:21:06 +0000
Subject: [PATCH] [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC)
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 121 +++---------------
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h | 6 +-
.../Target/NVPTX/NVPTXMachineFunctionInfo.h | 6 -
.../Target/NVPTX/NVPTXReplaceImageHandles.cpp | 51 +++-----
4 files changed, 38 insertions(+), 146 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index c8e29c1da6ec4..6e5dd6b15900c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -149,67 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
EmitToStreamer(*OutStreamer, Inst);
}
-// Handle symbol backtracking for targets that do not support image handles
-bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
- unsigned OpNo, MCOperand &MCOp) {
- const MachineOperand &MO = MI->getOperand(OpNo);
- const MCInstrDesc &MCID = MI->getDesc();
-
- if (MCID.TSFlags & NVPTXII::IsTexFlag) {
- // This is a texture fetch, so operand 4 is a texref and operand 5 is
- // a samplerref
- if (OpNo == 4 && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
- if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
- unsigned VecSize =
- 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
-
- // For a surface load of vector size N, the Nth operand will be the surfref
- if (OpNo == VecSize && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
- // This is a surface store, so operand 0 is a surfref
- if (OpNo == 0 && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
- // This is a query, so operand 1 is a surfref/texref
- if (OpNo == 1 && MO.isImm()) {
- lowerImageHandleSymbol(MO.getImm(), MCOp);
- return true;
- }
-
- return false;
- }
-
- return false;
-}
-
-void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
- // Ewwww
- TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
- NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
- const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
- StringRef Sym = MFI->getImageHandleSymbol(Index);
- StringRef SymName = nvTM.getStrPool().save(Sym);
- MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
-}
-
void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
OutMI.setOpcode(MI->getOpcode());
// Special: Do not mangle symbol operand of CALL_PROTOTYPE
@@ -220,67 +159,49 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
return;
}
- for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
- const MachineOperand &MO = MI->getOperand(i);
-
- MCOperand MCOp;
- if (lowerImageHandleOperand(MI, i, MCOp)) {
- OutMI.addOperand(MCOp);
- continue;
- }
-
- if (lowerOperand(MO, MCOp))
- OutMI.addOperand(MCOp);
- }
+ for (const auto &MO : MI->operands())
+ OutMI.addOperand(lowerOperand(MO));
}
-bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
- MCOperand &MCOp) {
+MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
switch (MO.getType()) {
- default: llvm_unreachable("unknown operand type");
+ default:
+ llvm_unreachable("unknown operand type");
case MachineOperand::MO_Register:
- MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
- break;
+ return MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
case MachineOperand::MO_Immediate:
- MCOp = MCOperand::createImm(MO.getImm());
- break;
+ return MCOperand::createImm(MO.getImm());
case MachineOperand::MO_MachineBasicBlock:
- MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
- MO.getMBB()->getSymbol(), OutContext));
- break;
+ return MCOperand::createExpr(
+ MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext));
case MachineOperand::MO_ExternalSymbol:
- MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
- break;
+ return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
case MachineOperand::MO_GlobalAddress:
- MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
- break;
+ return GetSymbolRef(getSymbol(MO.getGlobal()));
case MachineOperand::MO_FPImmediate: {
const ConstantFP *Cnt = MO.getFPImm();
const APFloat &Val = Cnt->getValueAPF();
switch (Cnt->getType()->getTypeID()) {
- default: report_fatal_error("Unsupported FP type"); break;
- case Type::HalfTyID:
- MCOp = MCOperand::createExpr(
- NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
+ default:
+ report_fatal_error("Unsupported FP type");
break;
+ case Type::HalfTyID:
+ return MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
case Type::BFloatTyID:
- MCOp = MCOperand::createExpr(
+ return MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
- break;
case Type::FloatTyID:
- MCOp = MCOperand::createExpr(
- NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
- break;
+ return MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
case Type::DoubleTyID:
- MCOp = MCOperand::createExpr(
- NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
- break;
+ return MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
}
break;
}
}
- return true;
}
unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index f7c3fda332eff..74daaa2fb7134 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
void emitInstruction(const MachineInstr *) override;
void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
- bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
+ MCOperand lowerOperand(const MachineOperand &MO);
MCOperand GetSymbolRef(const MCSymbol *Symbol);
unsigned encodeVirtualRegister(unsigned Reg);
@@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O);
void emitDemotedVars(const Function *, raw_ostream &);
- bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
- MCOperand &MCOp);
- void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
-
bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
// Used to control the need to emit .generic() in the initializer of
diff --git a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
index 6670cb296f216..d9beab7ec42e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
@@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo {
return ImageHandleList.size()-1;
}
- /// Returns the symbol name at the given index.
- StringRef getImageHandleSymbol(unsigned Idx) const {
- assert(ImageHandleList.size() > Idx && "Bad index");
- return ImageHandleList[Idx];
- }
-
/// Check if the symbol has a mapping. Having a mapping means the handle is
/// replaced with a reference
bool checkImageHandleSymbol(StringRef Symbol) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index a3e3978cbbfe2..4d0694faa0c9a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -20,7 +20,6 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
-#include "llvm/Support/raw_ostream.h"
using namespace llvm;
@@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass {
private:
bool processInstr(MachineInstr &MI);
bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF);
- bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF,
- unsigned &Idx);
};
-}
+} // namespace
char NVPTXReplaceImageHandles::ID = 0;
@@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
}
return true;
- } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
+ }
+ if (MCID.TSFlags & NVPTXII::IsSuldMask) {
unsigned VecSize =
- 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
+ 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
+ 1);
// For a surface load of vector size N, the Nth operand will be the surfref
MachineOperand &SurfHandle = MI.getOperand(VecSize);
@@ -1767,7 +1766,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode())));
return true;
- } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
+ }
+ if (MCID.TSFlags & NVPTXII::IsSustFlag) {
// This is a surface store, so operand 0 is a surfref
MachineOperand &SurfHandle = MI.getOperand(0);
@@ -1775,7 +1775,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode())));
return true;
- } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
+ }
+ if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
// This is a query, so operand 1 is a surfref/texref
MachineOperand &Handle = MI.getOperand(1);
@@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
MachineFunction &MF) {
- unsigned Idx;
- if (findIndexForHandle(Op, MF, Idx)) {
- Op.ChangeToImmediate(Idx);
- return true;
- }
- return false;
-}
-
-bool NVPTXReplaceImageHandles::
-findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
const MachineRegisterInfo &MRI = MF.getRegInfo();
NVPTXMachineFunctionInfo *MFI = MF.getInfo<NVPTXMachineFunctionInfo>();
@@ -1812,25 +1803,16 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
case NVPTX::LD_i64_avar: {
// The handle is a parameter value being loaded, replace with the
// parameter symbol
- const NVPTXTargetMachine &TM =
- static_cast<const NVPTXTargetMachine &>(MF.getTarget());
- if (TM.getDrvInterface() == NVPTX::CUDA) {
+ const auto &TM = static_cast<const NVPTXTargetMachine &>(MF.getTarget());
+ if (TM.getDrvInterface() == NVPTX::CUDA)
// For CUDA, we preserve the param loads coming from function arguments
return false;
- }
assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
- std::string ParamBaseName = std::string(MF.getName());
- ParamBaseName += "_param_";
- assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference");
- unsigned Param = atoi(Sym.data()+ParamBaseName.size());
- std::string NewSym;
- raw_string_ostream NewSymStr(NewSym);
- NewSymStr << MF.getName() << "_param_" << Param;
-
InstrsToRemove.insert(&TexHandleDef);
- Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str());
+ Op.ChangeToES(Sym.data());
+ MFI->getImageHandleSymbolIndex(Sym);
return true;
}
case NVPTX::texsurf_handles: {
@@ -1839,15 +1821,14 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal();
assert(GV->hasName() && "Global sampler must be named!");
InstrsToRemove.insert(&TexHandleDef);
- Idx = MFI->getImageHandleSymbolIndex(GV->getName());
+ Op.ChangeToGA(GV, 0);
return true;
}
case NVPTX::nvvm_move_i64:
case TargetOpcode::COPY: {
- bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx);
- if (Res) {
+ bool Res = replaceImageHandle(TexHandleDef.getOperand(1), MF);
+ if (Res)
InstrsToRemove.insert(&TexHandleDef);
- }
return Res;
}
default:
More information about the llvm-commits
mailing list