[llvm] r203483 - Make sure NVPTX doesn't emit symbol names that aren't valid in PTX.

Eli Bendersky eliben at google.com
Mon Mar 10 13:05:42 PDT 2014


Author: eliben
Date: Mon Mar 10 15:05:42 2014
New Revision: 203483

URL: http://llvm.org/viewvc/llvm-project?rev=203483&view=rev
Log:
Make sure NVPTX doesn't emit symbol names that aren't valid in PTX.

NVPTX, like the other backends, relies on the generic symbol-name sanitizing
done by MCSymbol. However, the ptxas assembler is more stringent and rejects
some additional characters in symbol names (notably '.' and '@').

See PR19099 for more details.


Modified:
    llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.h

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp?rev=203483&r1=203482&r2=203483&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp Mon Mar 10 15:05:42 2014
@@ -684,7 +684,7 @@ void NVPTXAsmPrinter::emitDeclaration(co
   else
     O << ".func ";
   printReturnValStr(F, O);
-  O << *getSymbol(F) << "\n";
+  O << getSymbolName(F) << "\n";
   emitFunctionParamList(F, O);
   O << ";\n";
 }
@@ -1209,7 +1209,7 @@ void NVPTXAsmPrinter::printModuleLevelGV
     else
       O << getPTXFundamentalTypeStr(ETy, false);
     O << " ";
-    O << *getSymbol(GVar);
+    O << getSymbolName(GVar);
 
     // Ptx allows variable initilization only for constant and global state
     // spaces.
@@ -1245,15 +1245,15 @@ void NVPTXAsmPrinter::printModuleLevelGV
           bufferAggregateConstant(Initializer, &aggBuffer);
           if (aggBuffer.numSymbols) {
             if (nvptxSubtarget.is64Bit()) {
-              O << " .u64 " << *getSymbol(GVar) << "[";
+              O << " .u64 " << getSymbolName(GVar) << "[";
               O << ElementSize / 8;
             } else {
-              O << " .u32 " << *getSymbol(GVar) << "[";
+              O << " .u32 " << getSymbolName(GVar) << "[";
               O << ElementSize / 4;
             }
             O << "]";
           } else {
-            O << " .b8 " << *getSymbol(GVar) << "[";
+            O << " .b8 " << getSymbolName(GVar) << "[";
             O << ElementSize;
             O << "]";
           }
@@ -1261,7 +1261,7 @@ void NVPTXAsmPrinter::printModuleLevelGV
           aggBuffer.print();
           O << "}";
         } else {
-          O << " .b8 " << *getSymbol(GVar);
+          O << " .b8 " << getSymbolName(GVar);
           if (ElementSize) {
             O << "[";
             O << ElementSize;
@@ -1269,7 +1269,7 @@ void NVPTXAsmPrinter::printModuleLevelGV
           }
         }
       } else {
-        O << " .b8 " << *getSymbol(GVar);
+        O << " .b8 " << getSymbolName(GVar);
         if (ElementSize) {
           O << "[";
           O << ElementSize;
@@ -1376,7 +1376,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVaria
     O << " .";
     O << getPTXFundamentalTypeStr(ETy);
     O << " ";
-    O << *getSymbol(GVar);
+    O << getSymbolName(GVar);
     return;
   }
 
@@ -1391,7 +1391,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVaria
   case Type::ArrayTyID:
   case Type::VectorTyID:
     ElementSize = TD->getTypeStoreSize(ETy);
-    O << " .b8 " << *getSymbol(GVar) << "[";
+    O << " .b8 " << getSymbolName(GVar) << "[";
     if (ElementSize) {
       O << itostr(ElementSize);
     }
@@ -1446,7 +1446,7 @@ void NVPTXAsmPrinter::printParamName(Fun
                                      int paramIndex, raw_ostream &O) {
   if ((nvptxSubtarget.getDrvInterface() == NVPTX::NVCL) ||
       (nvptxSubtarget.getDrvInterface() == NVPTX::CUDA))
-    O << *getSymbol(I->getParent()) << "_param_" << paramIndex;
+    O << getSymbolName(I->getParent()) << "_param_" << paramIndex;
   else {
     std::string argName = I->getName();
     const char *p = argName.c_str();
@@ -1505,13 +1505,13 @@ void NVPTXAsmPrinter::emitFunctionParamL
       if (llvm::isImage(*I)) {
         std::string sname = I->getName();
         if (llvm::isImageWriteOnly(*I))
-          O << "\t.param .surfref " << *getSymbol(F) << "_param_"
+          O << "\t.param .surfref " << getSymbolName(F) << "_param_"
             << paramIndex;
         else // Default image is read_only
-          O << "\t.param .texref " << *getSymbol(F) << "_param_"
+          O << "\t.param .texref " << getSymbolName(F) << "_param_"
             << paramIndex;
       } else // Should be llvm::isSampler(*I)
-        O << "\t.param .samplerref " << *getSymbol(F) << "_param_"
+        O << "\t.param .samplerref " << getSymbolName(F) << "_param_"
           << paramIndex;
       continue;
     }
@@ -1758,13 +1758,13 @@ void NVPTXAsmPrinter::printScalarConstan
     return;
   }
   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
-    O << *getSymbol(GVar);
+    O << getSymbolName(GVar);
     return;
   }
   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
     const Value *v = Cexpr->stripPointerCasts();
     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
-      O << *getSymbol(GVar);
+      O << getSymbolName(GVar);
       return;
     } else {
       O << *LowerConstant(CPV, *this);
@@ -2078,7 +2078,7 @@ void NVPTXAsmPrinter::printOperand(const
     break;
 
   case MachineOperand::MO_GlobalAddress:
-    O << *getSymbol(MO.getGlobal());
+    O << getSymbolName(MO.getGlobal());
     break;
 
   case MachineOperand::MO_MachineBasicBlock:
@@ -2139,6 +2139,33 @@ LineReader *NVPTXAsmPrinter::getReader(s
   return reader;
 }
 
+std::string NVPTXAsmPrinter::getSymbolName(const GlobalValue *GV) const {
+  // Print the MCSymbol for GV to get its (already sanitized) name.
+  MCSymbol *Sym = getSymbol(GV);
+  std::string OriginalName;
+  raw_string_ostream OriginalNameStream(OriginalName);
+  Sym->print(OriginalNameStream);
+  OriginalNameStream.flush();
+
+  // MCSymbol already does symbol-name sanitizing, so names it produces are
+  // valid for object files. The only two characters valid in that context
+  // but rejected by ptxas are '.' and '@'. Both become "_$_": PTX allows
+  // '%' only as the first character of an identifier, so "_%_" is invalid.
+  std::string CleanName;
+  CleanName.reserve(OriginalName.size());
+  for (unsigned I = 0, E = OriginalName.size(); I != E; ++I) {
+    char C = OriginalName[I];
+    if (C == '.' || C == '@') {
+      CleanName += "_$_";
+    } else {
+      CleanName += C;
+    }
+  }
+
+  // NOTE: '.' and '@' share one replacement, so distinct names could collide.
+  return CleanName;
+}
+
 std::string LineReader::readLine(unsigned lineNum) {
   if (lineNum < theCurLine) {
     theCurLine = 0;

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.h?rev=203483&r1=203482&r2=203483&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.h (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.h Mon Mar 10 15:05:42 2014
@@ -276,6 +276,11 @@ private:
 
   LineReader *reader;
   LineReader *getReader(std::string);
+
+  // Get the symbol name of the given global symbol.
+  //
+  // Cleans up the name so that it is a valid identifier in PTX assembly.
+  std::string getSymbolName(const GlobalValue *GV) const;
 public:
   NVPTXAsmPrinter(TargetMachine &TM, MCStreamer &Streamer)
       : AsmPrinter(TM, Streamer),





More information about the llvm-commits mailing list