[llvm] 5fe37ff - Revert "[NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC) (#126800)"

Mikhail Goncharov via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 12 02:13:22 PST 2025


Author: Mikhail Goncharov
Date: 2025-02-12T11:13:16+01:00
New Revision: 5fe37ff75ab5cdacd78933726009488068aabca5

URL: https://github.com/llvm/llvm-project/commit/5fe37ff75ab5cdacd78933726009488068aabca5
DIFF: https://github.com/llvm/llvm-project/commit/5fe37ff75ab5cdacd78933726009488068aabca5.diff

LOG: Revert "[NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC) (#126800)"

This reverts commit 215fa9e175c6ef9e2fa92f77fbd4015cd4c99a67.

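Reason: the reverted commit calls Value::getNameOrAsOperand() (in NVPTXAsmPrinter::emitLinkageDirective), which is only defined under DEBUG and is therefore unavailable in other build configurations.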

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
    llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
    llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 5b60151c14cc4..75d930d9f7b6f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -27,7 +27,6 @@
 #include "cl_common_defines.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallString.h"
@@ -48,7 +47,6 @@
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
-#include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -95,19 +93,20 @@ using namespace llvm;
 
 #define DEPOTNAME "__local_depot"
 
-/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
+/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
 /// depends.
 static void
-discoverDependentGlobals(const Value *V,
+DiscoverDependentGlobals(const Value *V,
                          DenseSet<const GlobalVariable *> &Globals) {
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
     Globals.insert(GV);
-    return;
+  else {
+    if (const User *U = dyn_cast<User>(V)) {
+      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
+        DiscoverDependentGlobals(U->getOperand(i), Globals);
+      }
+    }
   }
-
-  if (const User *U = dyn_cast<User>(V))
-    for (const auto &O : U->operands())
-      discoverDependentGlobals(O, Globals);
 }
 
 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
@@ -128,8 +127,8 @@ VisitGlobalVariableForEmission(const GlobalVariable *GV,
 
   // Make sure we visit all dependents first
   DenseSet<const GlobalVariable *> Others;
-  for (const auto &O : GV->operands())
-    discoverDependentGlobals(O, Others);
+  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
+    DiscoverDependentGlobals(GV->getOperand(i), Others);
 
   for (const GlobalVariable *GV : Others)
     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
@@ -624,8 +623,9 @@ static bool usedInGlobalVarDef(const Constant *C) {
   if (!C)
     return false;
 
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C))
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
     return GV->getName() != "llvm.used";
+  }
 
   for (const User *U : C->users())
     if (const Constant *C = dyn_cast<Constant>(U))
@@ -635,23 +635,25 @@ static bool usedInGlobalVarDef(const Constant *C) {
   return false;
 }
 
-static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
-  if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(U))
-    if (OtherGV->getName() == "llvm.used")
+static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
+  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
+    if (othergv->getName() == "llvm.used")
       return true;
+  }
 
-  if (const Instruction *I = dyn_cast<Instruction>(U)) {
-    if (const Function *CurFunc = I->getFunction()) {
-      if (OneFunc && (CurFunc != OneFunc))
+  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
+    if (instr->getParent() && instr->getParent()->getParent()) {
+      const Function *curFunc = instr->getParent()->getParent();
+      if (oneFunc && (curFunc != oneFunc))
         return false;
-      OneFunc = CurFunc;
+      oneFunc = curFunc;
       return true;
-    }
-    return false;
+    } else
+      return false;
   }
 
   for (const User *UU : U->users())
-    if (!usedInOneFunc(UU, OneFunc))
+    if (!usedInOneFunc(UU, oneFunc))
       return false;
 
   return true;
@@ -664,15 +666,16 @@ static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
  * 2. Does it have local linkage?
  * 3. Is the global variable referenced only in one function?
  */
-static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
-  if (!GV->hasLocalLinkage())
+static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
+  if (!gv->hasLocalLinkage())
     return false;
-  if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
+  PointerType *Pty = gv->getType();
+  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
     return false;
 
   const Function *oneFunc = nullptr;
 
-  bool flag = usedInOneFunc(GV, oneFunc);
+  bool flag = usedInOneFunc(gv, oneFunc);
   if (!flag)
     return false;
   if (!oneFunc)
@@ -682,22 +685,27 @@ static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
 }
 
 static bool useFuncSeen(const Constant *C,
-                        const SmallPtrSetImpl<const Function *> &SeenSet) {
+                        DenseMap<const Function *, bool> &seenMap) {
   for (const User *U : C->users()) {
     if (const Constant *cu = dyn_cast<Constant>(U)) {
-      if (useFuncSeen(cu, SeenSet))
+      if (useFuncSeen(cu, seenMap))
         return true;
     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
-      if (const Function *Caller = I->getFunction())
-        if (SeenSet.contains(Caller))
-          return true;
+      const BasicBlock *bb = I->getParent();
+      if (!bb)
+        continue;
+      const Function *caller = bb->getParent();
+      if (!caller)
+        continue;
+      if (seenMap.contains(caller))
+        return true;
     }
   }
   return false;
 }
 
 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
-  SmallPtrSet<const Function *, 32> SeenSet;
+  DenseMap<const Function *, bool> seenMap;
   for (const Function &F : M) {
     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
       emitDeclaration(&F, O);
@@ -723,7 +731,7 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
         }
         // Emit a declaration of this function if the function that
         // uses this constant expr has already been seen.
-        if (useFuncSeen(C, SeenSet)) {
+        if (useFuncSeen(C, seenMap)) {
           emitDeclaration(&F, O);
           break;
         }
@@ -731,19 +739,23 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
 
       if (!isa<Instruction>(U))
         continue;
-      const Function *Caller = cast<Instruction>(U)->getFunction();
-      if (!Caller)
+      const Instruction *instr = cast<Instruction>(U);
+      const BasicBlock *bb = instr->getParent();
+      if (!bb)
+        continue;
+      const Function *caller = bb->getParent();
+      if (!caller)
         continue;
 
       // If a caller has already been seen, then the caller is
       // appearing in the module before the callee. so print out
       // a declaration for the callee.
-      if (SeenSet.contains(Caller)) {
+      if (seenMap.contains(caller)) {
         emitDeclaration(&F, O);
         break;
       }
     }
-    SeenSet.insert(&F);
+    seenMap[&F] = true;
   }
   for (const GlobalAlias &GA : M.aliases())
     emitAliasDeclaration(&GA, O);
@@ -806,7 +818,7 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) {
 
   // Print out module-level global variables in proper order
   for (const GlobalVariable *GV : Globals)
-    printModuleLevelGV(GV, OS2, /*ProcessDemoted=*/false, STI);
+    printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
 
   OS2 << '\n';
 
@@ -827,14 +839,16 @@ void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
 
 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                  const NVPTXSubtarget &STI) {
-  const unsigned PTXVersion = STI.getPTXVersion();
+  O << "//\n";
+  O << "// Generated by LLVM NVPTX Back-End\n";
+  O << "//\n";
+  O << "\n";
 
-  O << "//\n"
-       "// Generated by LLVM NVPTX Back-End\n"
-       "//\n"
-       "\n"
-    << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"
-    << ".target " << STI.getTargetName();
+  unsigned PTXVersion = STI.getPTXVersion();
+  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
+
+  O << ".target ";
+  O << STI.getTargetName();
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   if (NTM.getDrvInterface() == NVPTX::NVCL)
@@ -857,9 +871,16 @@ void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
   if (HasFullDebugInfo)
     O << ", debug";
 
-  O << "\n"
-    << ".address_size " << (NTM.is64Bit() ? "64" : "32") << "\n"
-    << "\n";
+  O << "\n";
+
+  O << ".address_size ";
+  if (NTM.is64Bit())
+    O << "64";
+  else
+    O << "32";
+  O << "\n";
+
+  O << "\n";
 }
 
 bool NVPTXAsmPrinter::doFinalization(Module &M) {
@@ -907,28 +928,41 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                            raw_ostream &O) {
   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
     if (V->hasExternalLinkage()) {
-      if (const auto *GVar = dyn_cast<GlobalVariable>(V))
-        O << (GVar->hasInitializer() ? ".visible " : ".extern ");
-      else if (V->isDeclaration())
+      if (isa<GlobalVariable>(V)) {
+        const GlobalVariable *GVar = cast<GlobalVariable>(V);
+        if (GVar) {
+          if (GVar->hasInitializer())
+            O << ".visible ";
+          else
+            O << ".extern ";
+        }
+      } else if (V->isDeclaration())
         O << ".extern ";
       else
         O << ".visible ";
     } else if (V->hasAppendingLinkage()) {
-      report_fatal_error("Symbol '" + llvm::Twine(V->getNameOrAsOperand()) +
-                         "' has unsupported appending linkage type");
-    } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
+      std::string msg;
+      msg.append("Error: ");
+      msg.append("Symbol ");
+      if (V->hasName())
+        msg.append(std::string(V->getName()));
+      msg.append("has unsupported appending linkage type");
+      llvm_unreachable(msg.c_str());
+    } else if (!V->hasInternalLinkage() &&
+               !V->hasPrivateLinkage()) {
       O << ".weak ";
     }
   }
 }
 
 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
-                                         raw_ostream &O, bool ProcessDemoted,
+                                         raw_ostream &O, bool processDemoted,
                                          const NVPTXSubtarget &STI) {
   // Skip meta data
-  if (GVar->hasSection())
+  if (GVar->hasSection()) {
     if (GVar->getSection() == "llvm.metadata")
       return;
+  }
 
   // Skip LLVM intrinsic global variables
   if (GVar->getName().starts_with("llvm.") ||
@@ -1035,20 +1069,20 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   }
 
   if (GVar->hasPrivateLinkage()) {
-    if (GVar->getName().starts_with("unrollpragma"))
+    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
       return;
 
     // FIXME - need better way (e.g. Metadata) to avoid generating this global
-    if (GVar->getName().starts_with("filename"))
+    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
       return;
     if (GVar->use_empty())
       return;
   }
 
-  const Function *DemotedFunc = nullptr;
-  if (!ProcessDemoted && canDemoteGlobalVar(GVar, DemotedFunc)) {
+  const Function *demotedFunc = nullptr;
+  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
     O << "// " << GVar->getName() << " has been demoted\n";
-    localDecls[DemotedFunc].push_back(GVar);
+    localDecls[demotedFunc].push_back(GVar);
     return;
   }
 
@@ -1056,14 +1090,17 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   emitPTXAddressSpace(GVar->getAddressSpace(), O);
 
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
+    }
     O << " .attribute(.managed)";
   }
 
-  O << " .align "
-    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
+  if (MaybeAlign A = GVar->getAlign())
+    O << " .align " << A->value();
+  else
+    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
 
   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
@@ -1100,6 +1137,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
       }
     }
   } else {
+    uint64_t ElementSize = 0;
+
     // Although PTX has direct support for struct type and array type and
     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
     // targets that support these high level field accesses. Structs, arrays
@@ -1108,8 +1147,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     case Type::IntegerTyID: // Integers larger than 64 bits
     case Type::StructTyID:
     case Type::ArrayTyID:
-    case Type::FixedVectorTyID: {
-      const uint64_t ElementSize = DL.getTypeStoreSize(ETy);
+    case Type::FixedVectorTyID:
+      ElementSize = DL.getTypeStoreSize(ETy);
       // Ptx allows variable initilization only for constant and
       // global state spaces.
       if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
@@ -1120,7 +1159,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           AggBuffer aggBuffer(ElementSize, *this);
           bufferAggregateConstant(Initializer, &aggBuffer);
           if (aggBuffer.numSymbols()) {
-            const unsigned int ptrSize = MAI->getCodePointerSize();
+            unsigned int ptrSize = MAI->getCodePointerSize();
             if (ElementSize % ptrSize ||
                 !aggBuffer.allSymbolsAligned(ptrSize)) {
               // Print in bytes and use the mask() operator for pointers.
@@ -1151,17 +1190,22 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
         } else {
           O << " .b8 ";
           getSymbol(GVar)->print(O, MAI);
-          if (ElementSize)
-            O << "[" << ElementSize << "]";
+          if (ElementSize) {
+            O << "[";
+            O << ElementSize;
+            O << "]";
+          }
         }
       } else {
         O << " .b8 ";
         getSymbol(GVar)->print(O, MAI);
-        if (ElementSize)
-          O << "[" << ElementSize << "]";
+        if (ElementSize) {
+          O << "[";
+          O << ElementSize;
+          O << "]";
+        }
       }
       break;
-    }
     default:
       llvm_unreachable("type not supported yet");
     }
@@ -1185,7 +1229,7 @@ void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
       Name->print(os, AP.MAI);
     }
   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
-    const MCExpr *Expr = AP.lowerConstantForGV(CExpr, false);
+    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
     AP.printMCExpr(*Expr, os);
   } else
     llvm_unreachable("symbol type unknown");
@@ -1254,18 +1298,18 @@ void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
   }
 }
 
-void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
-  auto It = localDecls.find(F);
+void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
+  auto It = localDecls.find(f);
   if (It == localDecls.end())
     return;
 
-  ArrayRef<const GlobalVariable *> GVars = It->second;
+  std::vector<const GlobalVariable *> &gvars = It->second;
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const NVPTXSubtarget &STI =
       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
 
-  for (const GlobalVariable *GV : GVars) {
+  for (const GlobalVariable *GV : gvars) {
     O << "\t// demoted variable\n\t";
     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
   }
@@ -1300,11 +1344,13 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
     if (NumBits == 1)
       return "pred";
-    if (NumBits <= 64) {
+    else if (NumBits <= 64) {
       std::string name = "u";
       return name + utostr(NumBits);
+    } else {
+      llvm_unreachable("Integer too large");
+      break;
     }
-    llvm_unreachable("Integer too large");
     break;
   }
   case Type::BFloatTyID:
@@ -1347,14 +1393,16 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   O << ".";
   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
-
+    }
     O << " .attribute(.managed)";
   }
-  O << " .align "
-    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
+  if (MaybeAlign A = GVar->getAlign())
+    O << " .align " << A->value();
+  else
+    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
 
   // Special case for i128
   if (ETy->isIntegerTy(128)) {
@@ -1365,7 +1413,9 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   }
 
   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
-    O << " ." << getPTXFundamentalTypeStr(ETy) << " ";
+    O << " .";
+    O << getPTXFundamentalTypeStr(ETy);
+    O << " ";
     getSymbol(GVar)->print(O, MAI);
     return;
   }
@@ -1396,13 +1446,16 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
 
 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
+  const AttributeList &PAL = F->getAttributes();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
   const NVPTXMachineFunctionInfo *MFI =
       MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;
 
-  bool IsFirst = true;
-  const bool IsKernelFunc = isKernelFunction(*F);
+  Function::const_arg_iterator I, E;
+  unsigned paramIndex = 0;
+  bool first = true;
+  bool isKernelFunc = isKernelFunction(*F);
 
   if (F->arg_empty() && !F->isVarArg()) {
     O << "()";
@@ -1411,143 +1464,161 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
   O << "(\n";
 
-  for (const Argument &Arg : F->args()) {
-    Type *Ty = Arg.getType();
-    const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo());
+  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
+    Type *Ty = I->getType();
 
-    if (!IsFirst)
+    if (!first)
       O << ",\n";
 
-    IsFirst = false;
+    first = false;
 
     // Handle image/sampler parameters
-    if (IsKernelFunc) {
-      const bool IsSampler = isSampler(Arg);
-      const bool IsTexture = !IsSampler && isImageReadOnly(Arg);
-      const bool IsSurface = !IsSampler && !IsTexture &&
-                             (isImageReadWrite(Arg) || isImageWriteOnly(Arg));
-      if (IsSampler || IsTexture || IsSurface) {
-        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
-        O << "\t.param ";
-        if (EmitImgPtr)
-          O << ".u64 .ptr ";
-
-        if (IsSampler)
-          O << ".samplerref ";
-        else if (IsTexture)
-          O << ".texref ";
-        else // IsSurface
-          O << ".samplerref ";
-        O << ParamSym;
+    if (isKernelFunc) {
+      if (isSampler(*I) || isImage(*I)) {
+        std::string ParamSym;
+        raw_string_ostream ParamStr(ParamSym);
+        ParamStr << F->getName() << "_param_" << paramIndex;
+        ParamStr.flush();
+        bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
+        if (isImage(*I)) {
+          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
+            if (EmitImagePtr)
+              O << "\t.param .u64 .ptr .surfref ";
+            else
+              O << "\t.param .surfref ";
+            O << TLI->getParamName(F, paramIndex);
+          }
+          else { // Default image is read_only
+            if (EmitImagePtr)
+              O << "\t.param .u64 .ptr .texref ";
+            else
+              O << "\t.param .texref ";
+            O << TLI->getParamName(F, paramIndex);
+          }
+        } else {
+          if (EmitImagePtr)
+            O << "\t.param .u64 .ptr .samplerref ";
+          else
+            O << "\t.param .samplerref ";
+          O << TLI->getParamName(F, paramIndex);
+        }
         continue;
       }
     }
 
-    auto GetOptimalAlignForParam = [TLI, &DL, F, &Arg](Type *Ty) -> Align {
+    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
+                                    paramIndex](Type *Ty) -> Align {
       if (MaybeAlign StackAlign =
-              getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
+              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
         return StackAlign.value();
 
       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
-      MaybeAlign ParamAlign =
-          Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
+      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
       return std::max(TypeAlign, ParamAlign.valueOrOne());
     };
 
-    if (Arg.hasByValAttr()) {
-      // param has byVal attribute.
-      Type *ETy = Arg.getParamByValType();
-      assert(ETy && "Param should have byval type");
-
-      // Print .param .align <a> .b8 .param[size];
-      // <a>  = optimal alignment for the element type; always multiple of
-      //        PAL.getParamAlignment
-      // size = typeallocsize of element type
-      const Align OptimalAlign =
-          IsKernelFunc ? GetOptimalAlignForParam(ETy)
-                       : TLI->getFunctionByValParamAlign(
-                             F, ETy, Arg.getParamAlign().valueOrOne(), DL);
-
-      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
-        << "[" << DL.getTypeAllocSize(ETy) << "]";
-      continue;
-    }
-
-    if (ShouldPassAsArray(Ty)) {
-      // Just print .param .align <a> .b8 .param[size];
-      // <a>  = optimal alignment for the element type; always multiple of
-      //        PAL.getParamAlignment
-      // size = typeallocsize of element type
-      Align OptimalAlign = GetOptimalAlignForParam(Ty);
+    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
+      if (ShouldPassAsArray(Ty)) {
+        // Just print .param .align <a> .b8 .param[size];
+        // <a>  = optimal alignment for the element type; always multiple of
+        //        PAL.getParamAlignment
+        // size = typeallocsize of element type
+        Align OptimalAlign = getOptimalAlignForParam(Ty);
 
-      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
-        << "[" << DL.getTypeAllocSize(Ty) << "]";
+        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
+        O << TLI->getParamName(F, paramIndex);
+        O << "[" << DL.getTypeAllocSize(Ty) << "]";
 
-      continue;
-    }
-    // Just a scalar
-    auto *PTy = dyn_cast<PointerType>(Ty);
-    unsigned PTySizeInBits = 0;
-    if (PTy) {
-      PTySizeInBits =
-          TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
-      assert(PTySizeInBits && "Invalid pointer size");
-    }
-
-    if (IsKernelFunc) {
+        continue;
+      }
+      // Just a scalar
+      auto *PTy = dyn_cast<PointerType>(Ty);
+      unsigned PTySizeInBits = 0;
       if (PTy) {
-        O << "\t.param .u" << PTySizeInBits << " .ptr";
+        PTySizeInBits =
+            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
+        assert(PTySizeInBits && "Invalid pointer size");
+      }
 
-        switch (PTy->getAddressSpace()) {
-        default:
-          break;
-        case ADDRESS_SPACE_GLOBAL:
-          O << " .global";
-          break;
-        case ADDRESS_SPACE_SHARED:
-          O << " .shared";
-          break;
-        case ADDRESS_SPACE_CONST:
-          O << " .const";
-          break;
-        case ADDRESS_SPACE_LOCAL:
-          O << " .local";
-          break;
+      if (isKernelFunc) {
+        if (PTy) {
+          O << "\t.param .u" << PTySizeInBits << " .ptr";
+
+          switch (PTy->getAddressSpace()) {
+          default:
+            break;
+          case ADDRESS_SPACE_GLOBAL:
+            O << " .global";
+            break;
+          case ADDRESS_SPACE_SHARED:
+            O << " .shared";
+            break;
+          case ADDRESS_SPACE_CONST:
+            O << " .const";
+            break;
+          case ADDRESS_SPACE_LOCAL:
+            O << " .local";
+            break;
+          }
+
+          O << " .align " << I->getParamAlign().valueOrOne().value();
+          O << " " << TLI->getParamName(F, paramIndex);
+          continue;
         }
 
-        O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
-          << ParamSym;
+        // non-pointer scalar to kernel func
+        O << "\t.param .";
+        // Special case: predicate operands become .u8 types
+        if (Ty->isIntegerTy(1))
+          O << "u8";
+        else
+          O << getPTXFundamentalTypeStr(Ty);
+        O << " ";
+        O << TLI->getParamName(F, paramIndex);
         continue;
       }
-
-      // non-pointer scalar to kernel func
-      O << "\t.param .";
-      // Special case: predicate operands become .u8 types
-      if (Ty->isIntegerTy(1))
-        O << "u8";
-      else
-        O << getPTXFundamentalTypeStr(Ty);
-      O << " " << ParamSym;
+      // Non-kernel function, just print .param .b<size> for ABI
+      // and .reg .b<size> for non-ABI
+      unsigned sz = 0;
+      if (isa<IntegerType>(Ty)) {
+        sz = cast<IntegerType>(Ty)->getBitWidth();
+        sz = promoteScalarArgumentSize(sz);
+      } else if (PTy) {
+        assert(PTySizeInBits && "Invalid pointer size");
+        sz = PTySizeInBits;
+      } else
+        sz = Ty->getPrimitiveSizeInBits();
+      O << "\t.param .b" << sz << " ";
+      O << TLI->getParamName(F, paramIndex);
       continue;
     }
-    // Non-kernel function, just print .param .b<size> for ABI
-    // and .reg .b<size> for non-ABI
-    unsigned Size;
-    if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
-      Size = promoteScalarArgumentSize(ITy->getBitWidth());
-    } else if (PTy) {
-      assert(PTySizeInBits && "Invalid pointer size");
-      Size = PTySizeInBits;
-    } else
-      Size = Ty->getPrimitiveSizeInBits();
-    O << "\t.param .b" << Size << " " << ParamSym;
+
+    // param has byVal attribute.
+    Type *ETy = PAL.getParamByValType(paramIndex);
+    assert(ETy && "Param should have byval type");
+
+    // Print .param .align <a> .b8 .param[size];
+    // <a>  = optimal alignment for the element type; always multiple of
+    //        PAL.getParamAlignment
+    // size = typeallocsize of element type
+    Align OptimalAlign =
+        isKernelFunc
+            ? getOptimalAlignForParam(ETy)
+            : TLI->getFunctionByValParamAlign(
+                  F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
+
+    unsigned sz = DL.getTypeAllocSize(ETy);
+    O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
+    O << TLI->getParamName(F, paramIndex);
+    O << "[" << sz << "]";
   }
 
   if (F->isVarArg()) {
-    if (!IsFirst)
+    if (!first)
       O << ",\n";
-    O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
-      << TLI->getParamName(F, /* vararg */ -1) << "[]";
+    O << "\t.param .align " << STI.getMaxRequiredAlignment();
+    O << " .b8 ";
+    O << TLI->getParamName(F, /* vararg */ -1) << "[]";
   }
 
   O << "\n)";
@@ -1570,11 +1641,11 @@ void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
-      O << "\t.reg .b64 \t%SP;\n"
-        << "\t.reg .b64 \t%SPL;\n";
+      O << "\t.reg .b64 \t%SP;\n";
+      O << "\t.reg .b64 \t%SPL;\n";
     } else {
-      O << "\t.reg .b32 \t%SP;\n"
-        << "\t.reg .b32 \t%SPL;\n";
+      O << "\t.reg .b32 \t%SP;\n";
+      O << "\t.reg .b32 \t%SPL;\n";
     }
   }
 
@@ -1591,16 +1662,29 @@ void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
     regmap.insert(std::make_pair(vr, n + 1));
   }
 
+  // Emit register declarations
+  // @TODO: Extract out the real register usage
+  // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
+  // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
+  // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
+  // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
+  // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
+  // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
+  // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
+
   // Emit declaration of the virtual registers or 'physical' registers for
   // each register class
-  for (const TargetRegisterClass *RC : TRI->regclasses()) {
-    const unsigned N = VRegMapping[RC].size();
+  for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
+    const TargetRegisterClass *RC = TRI->getRegClass(i);
+    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
+    std::string rcname = getNVPTXRegClassName(RC);
+    std::string rcStr = getNVPTXRegClassStr(RC);
+    int n = regmap.size();
 
     // Only declare those registers that may be used.
-    if (N) {
-      const StringRef RCName = getNVPTXRegClassName(RC);
-      const StringRef RCStr = getNVPTXRegClassStr(RC);
-      O << "\t.reg " << RCName << " \t" << RCStr << "<" << (N + 1) << ">;\n";
+    if (n) {
+       O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
+         << ">;\n";
     }
   }
 
@@ -1627,8 +1711,7 @@ void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
   }
 }
 
-void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp,
-                                      raw_ostream &O) const {
+void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
   bool ignored;
   unsigned int numHex;
@@ -1663,7 +1746,10 @@ void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
     return;
   }
   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
-    const bool IsNonGenericPointer = GVar->getAddressSpace() != 0;
+    bool IsNonGenericPointer = false;
+    if (GVar->getType()->getAddressSpace() != 0) {
+      IsNonGenericPointer = true;
+    }
     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
       O << "generic(";
       getSymbol(GVar)->print(O, MAI);
@@ -1712,7 +1798,7 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
 
   switch (CPV->getType()->getTypeID()) {
   case Type::IntegerTyID:
-    if (const auto *CI = dyn_cast<ConstantInt>(CPV)) {
+    if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
       AddIntToBuffer(CI->getValue());
       break;
     }
@@ -1826,8 +1912,7 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
 /// expressions that are representable in PTX and create
 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
 const MCExpr *
-NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
-                                    bool ProcessingGeneric) const {
+NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
   MCContext &Ctx = OutContext;
 
   if (CV->isNullValue() || isa<UndefValue>(CV))
@@ -1837,10 +1922,13 @@ NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
 
   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
-    const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(getSymbol(GV), Ctx);
-    if (ProcessingGeneric)
+    const MCSymbolRefExpr *Expr =
+      MCSymbolRefExpr::create(getSymbol(GV), Ctx);
+    if (ProcessingGeneric) {
       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
-    return Expr;
+    } else {
+      return Expr;
+    }
   }
 
   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
@@ -1953,7 +2041,7 @@ NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
 }
 
 // Copy of MCExpr::print customized for NVPTX
-void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) const {
+void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
   switch (Expr.getKind()) {
   case MCExpr::Target:
     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index f7c3fda332eff..f58b4bdc40474 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -101,13 +101,15 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
     // SymbolsBeforeStripping[i].
     SmallVector<const Value *, 4> SymbolsBeforeStripping;
     unsigned curpos;
-    const NVPTXAsmPrinter &AP;
-    const bool EmitGeneric;
+    NVPTXAsmPrinter &AP;
+    bool EmitGeneric;
 
   public:
-    AggBuffer(unsigned size, const NVPTXAsmPrinter &AP)
-        : size(size), buffer(size), curpos(0), AP(AP),
-          EmitGeneric(AP.EmitGeneric) {}
+    AggBuffer(unsigned size, NVPTXAsmPrinter &AP)
+        : size(size), buffer(size), AP(AP) {
+      curpos = 0;
+      EmitGeneric = AP.EmitGeneric;
+    }
 
     // Copy Num bytes from Ptr.
     // if Bytes > Num, zero fill up to Bytes.
@@ -153,6 +155,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   StringRef getPassName() const override { return "NVPTX Assembly Printer"; }
 
   const Function *F;
+  std::string CurrentFnName;
 
   void emitStartOfAsmFile(Module &M) override;
   void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
@@ -187,9 +190,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo,
                              const char *ExtraCode, raw_ostream &) override;
 
-  const MCExpr *lowerConstantForGV(const Constant *CV,
-                                   bool ProcessingGeneric) const;
-  void printMCExpr(const MCExpr &Expr, raw_ostream &OS) const;
+  const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric);
+  void printMCExpr(const MCExpr &Expr, raw_ostream &OS);
 
 protected:
   bool doInitialization(Module &M) override;
@@ -215,7 +217,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const;
   std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const;
   void printScalarConstant(const Constant *CPV, raw_ostream &O);
-  void printFPConstant(const ConstantFP *Fp, raw_ostream &O) const;
+  void printFPConstant(const ConstantFP *Fp, raw_ostream &O);
   void bufferLEByte(const Constant *CPV, int Bytes, AggBuffer *aggBuffer);
   void bufferAggregateConstant(const Constant *CV, AggBuffer *aggBuffer);
 
@@ -243,7 +245,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
   // Since the address value should always be generic in CUDA C and always
   // be specific in OpenCL, we use this simple control here.
   //
-  const bool EmitGeneric;
+  bool EmitGeneric;
 
 public:
   NVPTXAsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer)

diff  --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
index 229c438edf723..d1b136429d3a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -24,7 +24,7 @@ using namespace llvm;
 #define DEBUG_TYPE "nvptx-reg-info"
 
 namespace llvm {
-StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
+std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
   if (RC == &NVPTX::Float32RegsRegClass)
     return ".f32";
   if (RC == &NVPTX::Float64RegsRegClass)
@@ -62,7 +62,7 @@ StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
   return "INTERNAL";
 }
 
-StringRef getNVPTXRegClassStr(TargetRegisterClass const *RC) {
+std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
   if (RC == &NVPTX::Float32RegsRegClass)
     return "%f";
   if (RC == &NVPTX::Float64RegsRegClass)
@@ -81,7 +81,7 @@ StringRef getNVPTXRegClassStr(TargetRegisterClass const *RC) {
     return "!Special!";
   return "INTERNAL";
 }
-} // namespace llvm
+}
 
 NVPTXRegisterInfo::NVPTXRegisterInfo()
     : NVPTXGenRegisterInfo(0), StrPool(StrAlloc) {}
@@ -144,10 +144,11 @@ void NVPTXRegisterInfo::clearDebugRegisterMap() const {
   debugRegisterMap.clear();
 }
 
-static uint64_t encodeRegisterForDwarf(StringRef RegisterName) {
-  if (RegisterName.size() > 8)
+static uint64_t encodeRegisterForDwarf(std::string registerName) {
+  if (registerName.length() > 8) {
     // The name is more than 8 characters long, and so won't fit into 64 bits.
     return 0;
+  }
 
   // Encode the name string into a DWARF register number using cuda-gdb's
   // encoding.  See cuda_check_dwarf2_reg_ptx_virtual_register in cuda-tdep.c,
@@ -156,14 +157,14 @@ static uint64_t encodeRegisterForDwarf(StringRef RegisterName) {
   // number, which is stored in ULEB128, but in practice must be no more than 8
   // bytes (excluding null terminator, which is not included).
   uint64_t result = 0;
-  for (unsigned char c : RegisterName)
+  for (unsigned char c : registerName)
     result = (result << 8) | c;
   return result;
 }
 
 void NVPTXRegisterInfo::addToDebugRegisterMap(
-    uint64_t preEncodedVirtualRegister, StringRef RegisterName) const {
-  uint64_t mapped = encodeRegisterForDwarf(RegisterName);
+    uint64_t preEncodedVirtualRegister, std::string registerName) const {
+  uint64_t mapped = encodeRegisterForDwarf(registerName);
   if (mapped == 0)
     return;
   debugRegisterMap.insert({preEncodedVirtualRegister, mapped});
@@ -171,13 +172,13 @@ void NVPTXRegisterInfo::addToDebugRegisterMap(
 
 int64_t NVPTXRegisterInfo::getDwarfRegNum(MCRegister RegNum, bool isEH) const {
   if (RegNum.isPhysical()) {
-    StringRef Name = NVPTXInstPrinter::getRegisterName(RegNum.id());
+    std::string name = NVPTXInstPrinter::getRegisterName(RegNum.id());
     // In NVPTXFrameLowering.cpp, we do arrange for %Depot to be accessible from
     // %SP. Using the %Depot register doesn't provide any debug info in
     // cuda-gdb, but switching it to %SP does.
     if (RegNum.id() == NVPTX::VRDepot)
-      Name = "%SP";
-    return encodeRegisterForDwarf(Name);
+      name = "%SP";
+    return encodeRegisterForDwarf(name);
   }
   uint64_t lookup = debugRegisterMap.lookup(RegNum.id());
   if (lookup)

diff  --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
index cfec7377fd634..d2f6d257d6b07 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
@@ -69,13 +69,13 @@ class NVPTXRegisterInfo : public NVPTXGenRegisterInfo {
   // here, because the proper encoding for debug registers is available only
   // temporarily during ASM emission.
   void addToDebugRegisterMap(uint64_t preEncodedVirtualRegister,
-                             StringRef RegisterName) const;
+                             std::string registerName) const;
   void clearDebugRegisterMap() const;
   int64_t getDwarfRegNum(MCRegister RegNum, bool isEH) const override;
 };
 
-StringRef getNVPTXRegClassName(const TargetRegisterClass *RC);
-StringRef getNVPTXRegClassStr(const TargetRegisterClass *RC);
+std::string getNVPTXRegClassName(const TargetRegisterClass *RC);
+std::string getNVPTXRegClassStr(const TargetRegisterClass *RC);
 
 } // end namespace llvm
 


        


More information about the llvm-commits mailing list