[llvm] ecdfa36 - Reland "[NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC)" (#127089)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 13 11:35:35 PST 2025


Author: Alex MacLean
Date: 2025-02-13T11:35:31-08:00
New Revision: ecdfa36ecaea7615b244f4cac26a4a023d30a9c1

URL: https://github.com/llvm/llvm-project/commit/ecdfa36ecaea7615b244f4cac26a4a023d30a9c1
DIFF: https://github.com/llvm/llvm-project/commit/ecdfa36ecaea7615b244f4cac26a4a023d30a9c1.diff

LOG: Reland "[NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC)" (#127089)

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
    llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
    llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 75d930d9f7b6f..0538b33530470 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -27,6 +27,7 @@
 #include "cl_common_defines.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallString.h"
@@ -47,6 +48,7 @@
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -93,20 +95,19 @@ using namespace llvm;
 
 #define DEPOTNAME "__local_depot"
 
-/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
+/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
 /// depends.
 static void
-DiscoverDependentGlobals(const Value *V,
+discoverDependentGlobals(const Value *V,
                          DenseSet<const GlobalVariable *> &Globals) {
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
     Globals.insert(GV);
-  else {
-    if (const User *U = dyn_cast<User>(V)) {
-      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
-        DiscoverDependentGlobals(U->getOperand(i), Globals);
-      }
-    }
+    return;
   }
+
+  if (const User *U = dyn_cast<User>(V))
+    for (const auto &O : U->operands())
+      discoverDependentGlobals(O, Globals);
 }
 
 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
@@ -127,8 +128,8 @@ VisitGlobalVariableForEmission(const GlobalVariable *GV,
 
   // Make sure we visit all dependents first
   DenseSet<const GlobalVariable *> Others;
-  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
-    DiscoverDependentGlobals(GV->getOperand(i), Others);
+  for (const auto &O : GV->operands())
+    discoverDependentGlobals(O, Others);
 
   for (const GlobalVariable *GV : Others)
     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
@@ -623,9 +624,8 @@ static bool usedInGlobalVarDef(const Constant *C) {
   if (!C)
     return false;
 
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C))
     return GV->getName() != "llvm.used";
-  }
 
   for (const User *U : C->users())
     if (const Constant *C = dyn_cast<Constant>(U))
@@ -635,25 +635,23 @@ static bool usedInGlobalVarDef(const Constant *C) {
   return false;
 }
 
-static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
-  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
-    if (othergv->getName() == "llvm.used")
+static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
+  if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(U))
+    if (OtherGV->getName() == "llvm.used")
       return true;
-  }
 
-  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
-    if (instr->getParent() && instr->getParent()->getParent()) {
-      const Function *curFunc = instr->getParent()->getParent();
-      if (oneFunc && (curFunc != oneFunc))
+  if (const Instruction *I = dyn_cast<Instruction>(U)) {
+    if (const Function *CurFunc = I->getFunction()) {
+      if (OneFunc && (CurFunc != OneFunc))
         return false;
-      oneFunc = curFunc;
+      OneFunc = CurFunc;
       return true;
-    } else
-      return false;
+    }
+    return false;
   }
 
   for (const User *UU : U->users())
-    if (!usedInOneFunc(UU, oneFunc))
+    if (!usedInOneFunc(UU, OneFunc))
       return false;
 
   return true;
@@ -666,16 +664,15 @@ static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
  * 2. Does it have local linkage?
  * 3. Is the global variable referenced only in one function?
  */
-static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
-  if (!gv->hasLocalLinkage())
+static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
+  if (!GV->hasLocalLinkage())
     return false;
-  PointerType *Pty = gv->getType();
-  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
+  if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
     return false;
 
   const Function *oneFunc = nullptr;
 
-  bool flag = usedInOneFunc(gv, oneFunc);
+  bool flag = usedInOneFunc(GV, oneFunc);
   if (!flag)
     return false;
   if (!oneFunc)
@@ -685,27 +682,22 @@ static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
 }
 
 static bool useFuncSeen(const Constant *C,
-                        DenseMap<const Function *, bool> &seenMap) {
+                        const SmallPtrSetImpl<const Function *> &SeenSet) {
   for (const User *U : C->users()) {
     if (const Constant *cu = dyn_cast<Constant>(U)) {
-      if (useFuncSeen(cu, seenMap))
+      if (useFuncSeen(cu, SeenSet))
         return true;
     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
-      const BasicBlock *bb = I->getParent();
-      if (!bb)
-        continue;
-      const Function *caller = bb->getParent();
-      if (!caller)
-        continue;
-      if (seenMap.contains(caller))
-        return true;
+      if (const Function *Caller = I->getFunction())
+        if (SeenSet.contains(Caller))
+          return true;
     }
   }
   return false;
 }
 
 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
-  DenseMap<const Function *, bool> seenMap;
+  SmallPtrSet<const Function *, 32> SeenSet;
   for (const Function &F : M) {
     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
       emitDeclaration(&F, O);
@@ -731,7 +723,7 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
         }
         // Emit a declaration of this function if the function that
         // uses this constant expr has already been seen.
-        if (useFuncSeen(C, seenMap)) {
+        if (useFuncSeen(C, SeenSet)) {
           emitDeclaration(&F, O);
           break;
         }
@@ -739,23 +731,19 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
 
       if (!isa<Instruction>(U))
         continue;
-      const Instruction *instr = cast<Instruction>(U);
-      const BasicBlock *bb = instr->getParent();
-      if (!bb)
-        continue;
-      const Function *caller = bb->getParent();
-      if (!caller)
+      const Function *Caller = cast<Instruction>(U)->getFunction();
+      if (!Caller)
         continue;
 
       // If a caller has already been seen, then the caller is
       // appearing in the module before the callee. so print out
       // a declaration for the callee.
-      if (seenMap.contains(caller)) {
+      if (SeenSet.contains(Caller)) {
         emitDeclaration(&F, O);
         break;
       }
     }
-    seenMap[&F] = true;
+    SeenSet.insert(&F);
   }
   for (const GlobalAlias &GA : M.aliases())
     emitAliasDeclaration(&GA, O);
@@ -818,7 +806,7 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) {
 
   // Print out module-level global variables in proper order
   for (const GlobalVariable *GV : Globals)
-    printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
+    printModuleLevelGV(GV, OS2, /*ProcessDemoted=*/false, STI);
 
   OS2 << '\n';
 
@@ -839,16 +827,14 @@ void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
 
 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                  const NVPTXSubtarget &STI) {
-  O << "//\n";
-  O << "// Generated by LLVM NVPTX Back-End\n";
-  O << "//\n";
-  O << "\n";
+  const unsigned PTXVersion = STI.getPTXVersion();
 
-  unsigned PTXVersion = STI.getPTXVersion();
-  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
-
-  O << ".target ";
-  O << STI.getTargetName();
+  O << "//\n"
+       "// Generated by LLVM NVPTX Back-End\n"
+       "//\n"
+       "\n"
+    << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"
+    << ".target " << STI.getTargetName();
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   if (NTM.getDrvInterface() == NVPTX::NVCL)
@@ -871,16 +857,9 @@ void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
   if (HasFullDebugInfo)
     O << ", debug";
 
-  O << "\n";
-
-  O << ".address_size ";
-  if (NTM.is64Bit())
-    O << "64";
-  else
-    O << "32";
-  O << "\n";
-
-  O << "\n";
+  O << "\n"
+    << ".address_size " << (NTM.is64Bit() ? "64" : "32") << "\n"
+    << "\n";
 }
 
 bool NVPTXAsmPrinter::doFinalization(Module &M) {
@@ -928,41 +907,28 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                            raw_ostream &O) {
   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
     if (V->hasExternalLinkage()) {
-      if (isa<GlobalVariable>(V)) {
-        const GlobalVariable *GVar = cast<GlobalVariable>(V);
-        if (GVar) {
-          if (GVar->hasInitializer())
-            O << ".visible ";
-          else
-            O << ".extern ";
-        }
-      } else if (V->isDeclaration())
+      if (const auto *GVar = dyn_cast<GlobalVariable>(V))
+        O << (GVar->hasInitializer() ? ".visible " : ".extern ");
+      else if (V->isDeclaration())
         O << ".extern ";
       else
         O << ".visible ";
     } else if (V->hasAppendingLinkage()) {
-      std::string msg;
-      msg.append("Error: ");
-      msg.append("Symbol ");
-      if (V->hasName())
-        msg.append(std::string(V->getName()));
-      msg.append("has unsupported appending linkage type");
-      llvm_unreachable(msg.c_str());
-    } else if (!V->hasInternalLinkage() &&
-               !V->hasPrivateLinkage()) {
+      report_fatal_error("Symbol '" + (V->hasName() ? V->getName() : "") +
+                         "' has unsupported appending linkage type");
+    } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
       O << ".weak ";
     }
   }
 }
 
 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
-                                         raw_ostream &O, bool processDemoted,
+                                         raw_ostream &O, bool ProcessDemoted,
                                          const NVPTXSubtarget &STI) {
   // Skip meta data
-  if (GVar->hasSection()) {
+  if (GVar->hasSection())
     if (GVar->getSection() == "llvm.metadata")
       return;
-  }
 
   // Skip LLVM intrinsic global variables
   if (GVar->getName().starts_with("llvm.") ||
@@ -1069,20 +1035,20 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   }
 
   if (GVar->hasPrivateLinkage()) {
-    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
+    if (GVar->getName().starts_with("unrollpragma"))
       return;
 
     // FIXME - need better way (e.g. Metadata) to avoid generating this global
-    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
+    if (GVar->getName().starts_with("filename"))
       return;
     if (GVar->use_empty())
       return;
   }
 
-  const Function *demotedFunc = nullptr;
-  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
+  const Function *DemotedFunc = nullptr;
+  if (!ProcessDemoted && canDemoteGlobalVar(GVar, DemotedFunc)) {
     O << "// " << GVar->getName() << " has been demoted\n";
-    localDecls[demotedFunc].push_back(GVar);
+    localDecls[DemotedFunc].push_back(GVar);
     return;
   }
 
@@ -1090,17 +1056,14 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   emitPTXAddressSpace(GVar->getAddressSpace(), O);
 
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
-    }
     O << " .attribute(.managed)";
   }
 
-  if (MaybeAlign A = GVar->getAlign())
-    O << " .align " << A->value();
-  else
-    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
+  O << " .align "
+    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
@@ -1137,8 +1100,6 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
       }
     }
   } else {
-    uint64_t ElementSize = 0;
-
     // Although PTX has direct support for struct type and array type and
     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
     // targets that support these high level field accesses. Structs, arrays
@@ -1147,8 +1108,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     case Type::IntegerTyID: // Integers larger than 64 bits
     case Type::StructTyID:
     case Type::ArrayTyID:
-    case Type::FixedVectorTyID:
-      ElementSize = DL.getTypeStoreSize(ETy);
+    case Type::FixedVectorTyID: {
+      const uint64_t ElementSize = DL.getTypeStoreSize(ETy);
       // Ptx allows variable initilization only for constant and
       // global state spaces.
       if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
@@ -1159,7 +1120,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           AggBuffer aggBuffer(ElementSize, *this);
           bufferAggregateConstant(Initializer, &aggBuffer);
           if (aggBuffer.numSymbols()) {
-            unsigned int ptrSize = MAI->getCodePointerSize();
+            const unsigned int ptrSize = MAI->getCodePointerSize();
             if (ElementSize % ptrSize ||
                 !aggBuffer.allSymbolsAligned(ptrSize)) {
               // Print in bytes and use the mask() operator for pointers.
@@ -1190,22 +1151,17 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
         } else {
           O << " .b8 ";
           getSymbol(GVar)->print(O, MAI);
-          if (ElementSize) {
-            O << "[";
-            O << ElementSize;
-            O << "]";
-          }
+          if (ElementSize)
+            O << "[" << ElementSize << "]";
         }
       } else {
         O << " .b8 ";
         getSymbol(GVar)->print(O, MAI);
-        if (ElementSize) {
-          O << "[";
-          O << ElementSize;
-          O << "]";
-        }
+        if (ElementSize)
+          O << "[" << ElementSize << "]";
       }
       break;
+    }
     default:
       llvm_unreachable("type not supported yet");
     }
@@ -1229,7 +1185,7 @@ void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
       Name->print(os, AP.MAI);
     }
   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
-    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
+    const MCExpr *Expr = AP.lowerConstantForGV(CExpr, false);
     AP.printMCExpr(*Expr, os);
   } else
     llvm_unreachable("symbol type unknown");
@@ -1298,18 +1254,18 @@ void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
   }
 }
 
-void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
-  auto It = localDecls.find(f);
+void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
+  auto It = localDecls.find(F);
   if (It == localDecls.end())
     return;
 
-  std::vector<const GlobalVariable *> &gvars = It->second;
+  ArrayRef<const GlobalVariable *> GVars = It->second;
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const NVPTXSubtarget &STI =
       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
 
-  for (const GlobalVariable *GV : gvars) {
+  for (const GlobalVariable *GV : GVars) {
     O << "\t// demoted variable\n\t";
     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
   }
@@ -1344,13 +1300,11 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
     if (NumBits == 1)
       return "pred";
-    else if (NumBits <= 64) {
+    if (NumBits <= 64) {
       std::string name = "u";
       return name + utostr(NumBits);
-    } else {
-      llvm_unreachable("Integer too large");
-      break;
     }
+    llvm_unreachable("Integer too large");
     break;
   }
   case Type::BFloatTyID:
@@ -1393,16 +1347,14 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   O << ".";
   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
-    }
+
     O << " .attribute(.managed)";
   }
-  if (MaybeAlign A = GVar->getAlign())
-    O << " .align " << A->value();
-  else
-    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
+  O << " .align "
+    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
   // Special case for i128
   if (ETy->isIntegerTy(128)) {
@@ -1413,9 +1365,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   }
 
   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
-    O << " .";
-    O << getPTXFundamentalTypeStr(ETy);
-    O << " ";
+    O << " ." << getPTXFundamentalTypeStr(ETy) << " ";
     getSymbol(GVar)->print(O, MAI);
     return;
   }
@@ -1446,16 +1396,13 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
 
 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
-  const AttributeList &PAL = F->getAttributes();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
   const NVPTXMachineFunctionInfo *MFI =
       MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;
 
-  Function::const_arg_iterator I, E;
-  unsigned paramIndex = 0;
-  bool first = true;
-  bool isKernelFunc = isKernelFunction(*F);
+  bool IsFirst = true;
+  const bool IsKernelFunc = isKernelFunction(*F);
 
   if (F->arg_empty() && !F->isVarArg()) {
     O << "()";
@@ -1464,161 +1411,143 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
   O << "(\n";
 
-  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
-    Type *Ty = I->getType();
+  for (const Argument &Arg : F->args()) {
+    Type *Ty = Arg.getType();
+    const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo());
 
-    if (!first)
+    if (!IsFirst)
       O << ",\n";
 
-    first = false;
+    IsFirst = false;
 
     // Handle image/sampler parameters
-    if (isKernelFunc) {
-      if (isSampler(*I) || isImage(*I)) {
-        std::string ParamSym;
-        raw_string_ostream ParamStr(ParamSym);
-        ParamStr << F->getName() << "_param_" << paramIndex;
-        ParamStr.flush();
-        bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
-        if (isImage(*I)) {
-          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .surfref ";
-            else
-              O << "\t.param .surfref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-          else { // Default image is read_only
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .texref ";
-            else
-              O << "\t.param .texref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-        } else {
-          if (EmitImagePtr)
-            O << "\t.param .u64 .ptr .samplerref ";
-          else
-            O << "\t.param .samplerref ";
-          O << TLI->getParamName(F, paramIndex);
-        }
+    if (IsKernelFunc) {
+      const bool IsSampler = isSampler(Arg);
+      const bool IsTexture = !IsSampler && isImageReadOnly(Arg);
+      const bool IsSurface = !IsSampler && !IsTexture &&
+                             (isImageReadWrite(Arg) || isImageWriteOnly(Arg));
+      if (IsSampler || IsTexture || IsSurface) {
+        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
+        O << "\t.param ";
+        if (EmitImgPtr)
+          O << ".u64 .ptr ";
+
+        if (IsSampler)
+          O << ".samplerref ";
+        else if (IsTexture)
+          O << ".texref ";
+        else // IsSurface: write-only or read-write image
+          O << ".surfref ";
+        O << ParamSym;
         continue;
       }
     }
 
-    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
-                                    paramIndex](Type *Ty) -> Align {
+    auto GetOptimalAlignForParam = [TLI, &DL, F, &Arg](Type *Ty) -> Align {
       if (MaybeAlign StackAlign =
-              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
+              getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
         return StackAlign.value();
 
       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
-      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
+      MaybeAlign ParamAlign =
+          Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
       return std::max(TypeAlign, ParamAlign.valueOrOne());
     };
 
-    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
-      if (ShouldPassAsArray(Ty)) {
-        // Just print .param .align <a> .b8 .param[size];
-        // <a>  = optimal alignment for the element type; always multiple of
-        //        PAL.getParamAlignment
-        // size = typeallocsize of element type
-        Align OptimalAlign = getOptimalAlignForParam(Ty);
+    if (Arg.hasByValAttr()) {
+      // param has byVal attribute.
+      Type *ETy = Arg.getParamByValType();
+      assert(ETy && "Param should have byval type");
+
+      // Print .param .align <a> .b8 .param[size];
+      // <a>  = optimal alignment for the element type; always multiple of
+      //        PAL.getParamAlignment
+      // size = typeallocsize of element type
+      const Align OptimalAlign =
+          IsKernelFunc ? GetOptimalAlignForParam(ETy)
+                       : TLI->getFunctionByValParamAlign(
+                             F, ETy, Arg.getParamAlign().valueOrOne(), DL);
+
+      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
+        << "[" << DL.getTypeAllocSize(ETy) << "]";
+      continue;
+    }
 
-        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
-        O << TLI->getParamName(F, paramIndex);
-        O << "[" << DL.getTypeAllocSize(Ty) << "]";
+    if (ShouldPassAsArray(Ty)) {
+      // Just print .param .align <a> .b8 .param[size];
+      // <a>  = optimal alignment for the element type; always multiple of
+      //        PAL.getParamAlignment
+      // size = typeallocsize of element type
+      Align OptimalAlign = GetOptimalAlignForParam(Ty);
 
-        continue;
-      }
-      // Just a scalar
-      auto *PTy = dyn_cast<PointerType>(Ty);
-      unsigned PTySizeInBits = 0;
-      if (PTy) {
-        PTySizeInBits =
-            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
-        assert(PTySizeInBits && "Invalid pointer size");
-      }
+      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
+        << "[" << DL.getTypeAllocSize(Ty) << "]";
 
-      if (isKernelFunc) {
-        if (PTy) {
-          O << "\t.param .u" << PTySizeInBits << " .ptr";
-
-          switch (PTy->getAddressSpace()) {
-          default:
-            break;
-          case ADDRESS_SPACE_GLOBAL:
-            O << " .global";
-            break;
-          case ADDRESS_SPACE_SHARED:
-            O << " .shared";
-            break;
-          case ADDRESS_SPACE_CONST:
-            O << " .const";
-            break;
-          case ADDRESS_SPACE_LOCAL:
-            O << " .local";
-            break;
-          }
+      continue;
+    }
+    // Just a scalar
+    auto *PTy = dyn_cast<PointerType>(Ty);
+    unsigned PTySizeInBits = 0;
+    if (PTy) {
+      PTySizeInBits =
+          TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
+      assert(PTySizeInBits && "Invalid pointer size");
+    }
 
-          O << " .align " << I->getParamAlign().valueOrOne().value();
-          O << " " << TLI->getParamName(F, paramIndex);
-          continue;
+    if (IsKernelFunc) {
+      if (PTy) {
+        O << "\t.param .u" << PTySizeInBits << " .ptr";
+
+        switch (PTy->getAddressSpace()) {
+        default:
+          break;
+        case ADDRESS_SPACE_GLOBAL:
+          O << " .global";
+          break;
+        case ADDRESS_SPACE_SHARED:
+          O << " .shared";
+          break;
+        case ADDRESS_SPACE_CONST:
+          O << " .const";
+          break;
+        case ADDRESS_SPACE_LOCAL:
+          O << " .local";
+          break;
         }
 
-        // non-pointer scalar to kernel func
-        O << "\t.param .";
-        // Special case: predicate operands become .u8 types
-        if (Ty->isIntegerTy(1))
-          O << "u8";
-        else
-          O << getPTXFundamentalTypeStr(Ty);
-        O << " ";
-        O << TLI->getParamName(F, paramIndex);
+        O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
+          << ParamSym;
         continue;
       }
-      // Non-kernel function, just print .param .b<size> for ABI
-      // and .reg .b<size> for non-ABI
-      unsigned sz = 0;
-      if (isa<IntegerType>(Ty)) {
-        sz = cast<IntegerType>(Ty)->getBitWidth();
-        sz = promoteScalarArgumentSize(sz);
-      } else if (PTy) {
-        assert(PTySizeInBits && "Invalid pointer size");
-        sz = PTySizeInBits;
-      } else
-        sz = Ty->getPrimitiveSizeInBits();
-      O << "\t.param .b" << sz << " ";
-      O << TLI->getParamName(F, paramIndex);
+
+      // non-pointer scalar to kernel func
+      O << "\t.param .";
+      // Special case: predicate operands become .u8 types
+      if (Ty->isIntegerTy(1))
+        O << "u8";
+      else
+        O << getPTXFundamentalTypeStr(Ty);
+      O << " " << ParamSym;
       continue;
     }
-
-    // param has byVal attribute.
-    Type *ETy = PAL.getParamByValType(paramIndex);
-    assert(ETy && "Param should have byval type");
-
-    // Print .param .align <a> .b8 .param[size];
-    // <a>  = optimal alignment for the element type; always multiple of
-    //        PAL.getParamAlignment
-    // size = typeallocsize of element type
-    Align OptimalAlign =
-        isKernelFunc
-            ? getOptimalAlignForParam(ETy)
-            : TLI->getFunctionByValParamAlign(
-                  F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
-
-    unsigned sz = DL.getTypeAllocSize(ETy);
-    O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
-    O << TLI->getParamName(F, paramIndex);
-    O << "[" << sz << "]";
+    // Non-kernel function, just print .param .b<size> for ABI
+    // and .reg .b<size> for non-ABI
+    unsigned Size;
+    if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+      Size = promoteScalarArgumentSize(ITy->getBitWidth());
+    } else if (PTy) {
+      assert(PTySizeInBits && "Invalid pointer size");
+      Size = PTySizeInBits;
+    } else
+      Size = Ty->getPrimitiveSizeInBits();
+    O << "\t.param .b" << Size << " " << ParamSym;
   }
 
   if (F->isVarArg()) {
-    if (!first)
+    if (!IsFirst)
       O << ",\n";
-    O << "\t.param .align " << STI.getMaxRequiredAlignment();
-    O << " .b8 ";
-    O << TLI->getParamName(F, /* vararg */ -1) << "[]";
+    O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
+      << TLI->getParamName(F, /* vararg */ -1) << "[]";
   }
 
   O << "\n)";
@@ -1641,11 +1570,11 @@ void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
-      O << "\t.reg .b64 \t%SP;\n";
-      O << "\t.reg .b64 \t%SPL;\n";
+      O << "\t.reg .b64 \t%SP;\n"
+        << "\t.reg .b64 \t%SPL;\n";
     } else {
-      O << "\t.reg .b32 \t%SP;\n";
-      O << "\t.reg .b32 \t%SPL;\n";
+      O << "\t.reg .b32 \t%SP;\n"
+        << "\t.reg .b32 \t%SPL;\n";
     }
   }
 
@@ -1662,29 +1591,16 @@ void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
     regmap.insert(std::make_pair(vr, n + 1));
   }
 
-  // Emit register declarations
-  // @TODO: Extract out the real register usage
-  // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
-  // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
-  // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
-  // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
-  // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
-  // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
-  // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
-
   // Emit declaration of the virtual registers or 'physical' registers for
   // each register class
-  for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
-    const TargetRegisterClass *RC = TRI->getRegClass(i);
-    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
-    std::string rcname = getNVPTXRegClassName(RC);
-    std::string rcStr = getNVPTXRegClassStr(RC);
-    int n = regmap.size();
+  for (const TargetRegisterClass *RC : TRI->regclasses()) {
+    const unsigned N = VRegMapping[RC].size();
 
     // Only declare those registers that may be used.
-    if (n) {
-       O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
-         << ">;\n";
+    if (N) {
+      const StringRef RCName = getNVPTXRegClassName(RC);
+      const StringRef RCStr = getNVPTXRegClassStr(RC);
+      O << "\t.reg " << RCName << " \t" << RCStr << "<" << (N + 1) << ">;\n";
     }
   }
 
@@ -1711,7 +1627,8 @@ void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
   }
 }
 
-void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
+void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp,
+                                      raw_ostream &O) const {
   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
   bool ignored;
   unsigned int numHex;
@@ -1746,10 +1663,7 @@ void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
     return;
   }
   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
-    bool IsNonGenericPointer = false;
-    if (GVar->getType()->getAddressSpace() != 0) {
-      IsNonGenericPointer = true;
-    }
+    const bool IsNonGenericPointer = GVar->getAddressSpace() != 0;
     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
       O << "generic(";
       getSymbol(GVar)->print(O, MAI);
@@ -1798,7 +1712,7 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
 
   switch (CPV->getType()->getTypeID()) {
   case Type::IntegerTyID:
-    if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
+    if (const auto *CI = dyn_cast<ConstantInt>(CPV)) {
       AddIntToBuffer(CI->getValue());
       break;
     }
@@ -1912,7 +1826,8 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
 /// expressions that are representable in PTX and create
 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
 const MCExpr *
-NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
+NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
+                                    bool ProcessingGeneric) const {
   MCContext &Ctx = OutContext;
 
   if (CV->isNullValue() || isa<UndefValue>(CV))
@@ -1922,13 +1837,10 @@ NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric)
     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
 
   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
-    const MCSymbolRefExpr *Expr =
-      MCSymbolRefExpr::create(getSymbol(GV), Ctx);
-    if (ProcessingGeneric) {
+    const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(getSymbol(GV), Ctx);
+    if (ProcessingGeneric)
       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
-    } else {
-      return Expr;
-    }
+    return Expr;
   }
 
   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
@@ -2041,7 +1953,7 @@ NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric)
 }
 
 // Copy of MCExpr::print customized for NVPTX
-void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
+void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) const {
   switch (Expr.getKind()) {
   case MCExpr::Target:
     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index f58b4bdc40474..f7c3fda332eff 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -101,15 +101,13 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
     // SymbolsBeforeStripping[i].
     SmallVector<const Value *, 4> SymbolsBeforeStripping;
     unsigned curpos;
-    NVPTXAsmPrinter &AP;
-    bool EmitGeneric;
+    const NVPTXAsmPrinter &AP;
+    const bool EmitGeneric;
 
   public:
-    AggBuffer(unsigned size, NVPTXAsmPrinter &AP)
-        : size(size), buffer(size), AP(AP) {
-      curpos = 0;
-      EmitGeneric = AP.EmitGeneric;
-    }
+    AggBuffer(unsigned size, const NVPTXAsmPrinter &AP)
+        : size(size), buffer(size), curpos(0), AP(AP),
+          EmitGeneric(AP.EmitGeneric) {}
 
     // Copy Num bytes from Ptr.
     // if Bytes > Num, zero fill up to Bytes.
@@ -155,7 +153,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   StringRef getPassName() const override { return "NVPTX Assembly Printer"; }
 
   const Function *F;
-  std::string CurrentFnName;
 
   void emitStartOfAsmFile(Module &M) override;
   void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
@@ -190,8 +187,9 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo,
                              const char *ExtraCode, raw_ostream &) override;
 
-  const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric);
-  void printMCExpr(const MCExpr &Expr, raw_ostream &OS);
+  const MCExpr *lowerConstantForGV(const Constant *CV,
+                                   bool ProcessingGeneric) const;
+  void printMCExpr(const MCExpr &Expr, raw_ostream &OS) const;
 
 protected:
   bool doInitialization(Module &M) override;
@@ -217,7 +215,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const;
   std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const;
   void printScalarConstant(const Constant *CPV, raw_ostream &O);
-  void printFPConstant(const ConstantFP *Fp, raw_ostream &O);
+  void printFPConstant(const ConstantFP *Fp, raw_ostream &O) const;
   void bufferLEByte(const Constant *CPV, int Bytes, AggBuffer *aggBuffer);
   void bufferAggregateConstant(const Constant *CV, AggBuffer *aggBuffer);
 
@@ -245,7 +243,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   // Since the address value should always be generic in CUDA C and always
   // be specific in OpenCL, we use this simple control here.
   //
-  bool EmitGeneric;
+  const bool EmitGeneric;
 
 public:
   NVPTXAsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer)

diff  --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
index d1b136429d3a4..229c438edf723 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -24,7 +24,7 @@ using namespace llvm;
 #define DEBUG_TYPE "nvptx-reg-info"
 
 namespace llvm {
-std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
+StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
   if (RC == &NVPTX::Float32RegsRegClass)
     return ".f32";
   if (RC == &NVPTX::Float64RegsRegClass)
@@ -62,7 +62,7 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
   return "INTERNAL";
 }
 
-std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
+StringRef getNVPTXRegClassStr(TargetRegisterClass const *RC) {
   if (RC == &NVPTX::Float32RegsRegClass)
     return "%f";
   if (RC == &NVPTX::Float64RegsRegClass)
@@ -81,7 +81,7 @@ std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
     return "!Special!";
   return "INTERNAL";
 }
-}
+} // namespace llvm
 
 NVPTXRegisterInfo::NVPTXRegisterInfo()
     : NVPTXGenRegisterInfo(0), StrPool(StrAlloc) {}
@@ -144,11 +144,10 @@ void NVPTXRegisterInfo::clearDebugRegisterMap() const {
   debugRegisterMap.clear();
 }
 
-static uint64_t encodeRegisterForDwarf(std::string registerName) {
-  if (registerName.length() > 8) {
+static uint64_t encodeRegisterForDwarf(StringRef RegisterName) {
+  if (RegisterName.size() > 8)
     // The name is more than 8 characters long, and so won't fit into 64 bits.
     return 0;
-  }
 
   // Encode the name string into a DWARF register number using cuda-gdb's
   // encoding.  See cuda_check_dwarf2_reg_ptx_virtual_register in cuda-tdep.c,
@@ -157,14 +156,14 @@ static uint64_t encodeRegisterForDwarf(std::string registerName) {
   // number, which is stored in ULEB128, but in practice must be no more than 8
   // bytes (excluding null terminator, which is not included).
   uint64_t result = 0;
-  for (unsigned char c : registerName)
+  for (unsigned char c : RegisterName)
     result = (result << 8) | c;
   return result;
 }
 
 void NVPTXRegisterInfo::addToDebugRegisterMap(
-    uint64_t preEncodedVirtualRegister, std::string registerName) const {
-  uint64_t mapped = encodeRegisterForDwarf(registerName);
+    uint64_t preEncodedVirtualRegister, StringRef RegisterName) const {
+  uint64_t mapped = encodeRegisterForDwarf(RegisterName);
   if (mapped == 0)
     return;
   debugRegisterMap.insert({preEncodedVirtualRegister, mapped});
@@ -172,13 +171,13 @@ void NVPTXRegisterInfo::addToDebugRegisterMap(
 
 int64_t NVPTXRegisterInfo::getDwarfRegNum(MCRegister RegNum, bool isEH) const {
   if (RegNum.isPhysical()) {
-    std::string name = NVPTXInstPrinter::getRegisterName(RegNum.id());
+    StringRef Name = NVPTXInstPrinter::getRegisterName(RegNum.id());
     // In NVPTXFrameLowering.cpp, we do arrange for %Depot to be accessible from
     // %SP. Using the %Depot register doesn't provide any debug info in
     // cuda-gdb, but switching it to %SP does.
     if (RegNum.id() == NVPTX::VRDepot)
-      name = "%SP";
-    return encodeRegisterForDwarf(name);
+      Name = "%SP";
+    return encodeRegisterForDwarf(Name);
   }
   uint64_t lookup = debugRegisterMap.lookup(RegNum.id());
   if (lookup)

diff  --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
index d2f6d257d6b07..cfec7377fd634 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
@@ -69,13 +69,13 @@ class NVPTXRegisterInfo : public NVPTXGenRegisterInfo {
   // here, because the proper encoding for debug registers is available only
   // temporarily during ASM emission.
   void addToDebugRegisterMap(uint64_t preEncodedVirtualRegister,
-                             std::string registerName) const;
+                             StringRef RegisterName) const;
   void clearDebugRegisterMap() const;
   int64_t getDwarfRegNum(MCRegister RegNum, bool isEH) const override;
 };
 
-std::string getNVPTXRegClassName(const TargetRegisterClass *RC);
-std::string getNVPTXRegClassStr(const TargetRegisterClass *RC);
+StringRef getNVPTXRegClassName(const TargetRegisterClass *RC);
+StringRef getNVPTXRegClassStr(const TargetRegisterClass *RC);
 
 } // end namespace llvm
 


        


More information about the llvm-commits mailing list