[llvm-commits] [llvm] r144388 - in /llvm/trunk: lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp lib/Target/PTX/PTXAsmPrinter.cpp lib/Target/PTX/PTXAsmPrinter.h lib/Target/PTX/PTXISelLowering.cpp test/CodeGen/PTX/printf.ll

Dan Bailey dan at dneg.com
Fri Nov 11 06:45:12 PST 2011


Author: drb
Date: Fri Nov 11 08:45:12 2011
New Revision: 144388

URL: http://llvm.org/viewvc/llvm-project?rev=144388&view=rev
Log:
allow non-device function calls in PTX when natively handling device-side printf

Added:
    llvm/trunk/test/CodeGen/PTX/printf.ll
Modified:
    llvm/trunk/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp
    llvm/trunk/lib/Target/PTX/PTXAsmPrinter.cpp
    llvm/trunk/lib/Target/PTX/PTXAsmPrinter.h
    llvm/trunk/lib/Target/PTX/PTXISelLowering.cpp

Modified: llvm/trunk/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp?rev=144388&r1=144387&r2=144388&view=diff
==============================================================================
--- llvm/trunk/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp (original)
+++ llvm/trunk/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp Fri Nov 11 08:45:12 2011
@@ -96,9 +96,23 @@
     O << "), ";
   }
 
-  O << *(MI->getOperand(Index++).getExpr()) << ", (";
-
+  const MCExpr* Expr = MI->getOperand(Index++).getExpr();
   unsigned NumArgs = MI->getOperand(Index++).getImm();
+  
+  // if the function call is to printf or puts, change to vprintf
+  if (const MCSymbolRefExpr *SymRefExpr = dyn_cast<MCSymbolRefExpr>(Expr)) {
+    const MCSymbol &Sym = SymRefExpr->getSymbol();
+    if (Sym.getName() == "printf" || Sym.getName() == "puts") {
+      O << "vprintf";
+    } else {
+      O << Sym.getName();
+    }
+  } else {
+    O << *Expr;
+  }
+  
+  O << ", (";
+
   if (NumArgs > 0) {
     printOperand(MI, Index++, O);
     for (unsigned i = 1; i < NumArgs; ++i) {

Modified: llvm/trunk/lib/Target/PTX/PTXAsmPrinter.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/PTX/PTXAsmPrinter.cpp?rev=144388&r1=144387&r2=144388&view=diff
==============================================================================
--- llvm/trunk/lib/Target/PTX/PTXAsmPrinter.cpp (original)
+++ llvm/trunk/lib/Target/PTX/PTXAsmPrinter.cpp Fri Nov 11 08:45:12 2011
@@ -165,6 +165,11 @@
 
   OutStreamer.AddBlankLine();
 
+  // declare external functions
+  for (Module::const_iterator i = M.begin(), e = M.end();
+       i != e; ++i)
+    EmitFunctionDeclaration(i);
+  
   // declare global variables
   for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
        i != e; ++i)
@@ -454,6 +459,31 @@
   OutStreamer.EmitRawText(os.str());
 }
 
+void PTXAsmPrinter::EmitFunctionDeclaration(const Function* func)
+{
+  // Only printf/puts need a declaration: the instruction printer renames
+  // calls to them to vprintf, so emit a single .extern for vprintf and
+  // nothing at all (not even a blank line) for any other function.
+  StringRef Name = func->getName();
+  if (Name != "printf" && Name != "puts")
+    return;
+
+  // If the module declares both, let the printf entry emit the declaration
+  // so that vprintf is not declared twice.
+  if (Name == "puts" && func->getParent()->getFunction("printf"))
+    return;
+
+  // vprintf returns a .b32 and takes two pointer-sized params: the format
+  // string address and the address of the packed argument buffer.
+  const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
+  const char *PtrBits = ST.is64Bit() ? "64" : "32";
+
+  OutStreamer.EmitRawText(Twine(".extern .func (.param .b32 __param_1) "
+                                "vprintf (.param .b") + PtrBits +
+                          " __param_2, .param .b" + PtrBits +
+                          " __param_3)\n");
+}
+
 unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
                                             StringRef DirName) {
   // If FE did not provide a file name, then assume stdin.

Modified: llvm/trunk/lib/Target/PTX/PTXAsmPrinter.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/PTX/PTXAsmPrinter.h?rev=144388&r1=144387&r2=144388&view=diff
==============================================================================
--- llvm/trunk/lib/Target/PTX/PTXAsmPrinter.h (original)
+++ llvm/trunk/lib/Target/PTX/PTXAsmPrinter.h Fri Nov 11 08:45:12 2011
@@ -47,7 +47,7 @@
 
 private:
   void EmitVariableDeclaration(const GlobalVariable *gv);
-  void EmitFunctionDeclaration();
+  void EmitFunctionDeclaration(const Function* func);
 
   StringMap<unsigned> SourceIdMap;
 }; // class PTXAsmPrinter

Modified: llvm/trunk/lib/Target/PTX/PTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/PTX/PTXISelLowering.cpp?rev=144388&r1=144387&r2=144388&view=diff
==============================================================================
--- llvm/trunk/lib/Target/PTX/PTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/PTX/PTXISelLowering.cpp Fri Nov 11 08:45:12 2011
@@ -20,6 +20,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
@@ -352,40 +353,101 @@
                              SmallVectorImpl<SDValue> &InVals) const {
 
   MachineFunction& MF = DAG.getMachineFunction();
-  PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
-  PTXParamManager &PM = MFI->getParamManager();
-
+  PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>();
+  PTXParamManager &PM = PTXMFI->getParamManager();
+  MachineFrameInfo *MFI = MF.getFrameInfo();
+  
   assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
          "Calls are not handled for the target device");
 
+  // Identify the callee function
+  const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
+  const Function *function = cast<Function>(GV);
+  
+  // allow non-device calls only for printf
+  bool isPrintf = function->getName() == "printf" || function->getName() == "puts";	
+  
+  assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) &&
+			 "PTX function calls must be to PTX device functions");
+  
+  unsigned outSize = isPrintf ? 2 : Outs.size();
+  
   std::vector<SDValue> Ops;
   // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
-  Ops.resize(Outs.size() + Ins.size() + 4);
+  Ops.resize(outSize + Ins.size() + 4);
 
   Ops[0] = Chain;
 
   // Identify the callee function
-  const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
-  assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
-         "PTX function calls must be to PTX device functions");
   Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
   Ops[Ins.size()+2] = Callee;
 
-  // Generate STORE_PARAM nodes for each function argument.  In PTX, function
-  // arguments are explicitly stored into .param variables and passed as
-  // arguments. There is no register/stack-based calling convention in PTX.
-  Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
-  for (unsigned i = 0; i != OutVals.size(); ++i) {
-    unsigned Size = OutVals[i].getValueType().getSizeInBits();
-    unsigned Param = PM.addLocalParam(Size);
-    const std::string &ParamName = PM.getParamName(Param);
-    SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
-                                                     MVT::Other);
+  // #Outs
+  Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32);
+  
+  if (isPrintf) {
+    // first argument is the address of the global string variable in memory
+    unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits());
+    SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(),
+                                                      MVT::Other);
     Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
-                        ParamValue, OutVals[i]);
-    Ops[i+Ins.size()+4] = ParamValue;
-  }
+                        ParamValue0, OutVals[0]);
+    Ops[Ins.size()+4] = ParamValue0;
+      
+    // alignment is the maximum size of all the arguments
+    unsigned alignment = 0;
+    for (unsigned i = 1; i < OutVals.size(); ++i) {
+      alignment = std::max(alignment, 
+    		               OutVals[i].getValueType().getSizeInBits());
+    }
+
+    // size is the alignment multiplied by the number of arguments
+    unsigned size = alignment * (OutVals.size() - 1);
+  
+    // second argument is the address of the stack object (unless no arguments)
+    unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits());
+    SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(),
+                                                      MVT::Other);
+    Ops[Ins.size()+5] = ParamValue1;
+    
+    if (size > 0)
+    {
+      // create a local stack object to store the arguments
+      unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false);
+      SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy());
+	  
+      // store each of the arguments to the stack in turn
+      for (unsigned int i = 1; i != OutVals.size(); i++) {
+        SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy()));
+        Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr,
+                             MachinePointerInfo(),
+                             false, false, 0);
+      }
 
+      // copy the address of the local frame index to get the address in non-local space
+      SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex);
+
+      // store this address in the second argument
+      Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr);
+    }
+  }
+  else
+  {
+	  // Generate STORE_PARAM nodes for each function argument.  In PTX, function
+	  // arguments are explicitly stored into .param variables and passed as
+	  // arguments. There is no register/stack-based calling convention in PTX.
+	  for (unsigned i = 0; i != OutVals.size(); ++i) {
+		unsigned Size = OutVals[i].getValueType().getSizeInBits();
+		unsigned Param = PM.addLocalParam(Size);
+		const std::string &ParamName = PM.getParamName(Param);
+		SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
+														 MVT::Other);
+		Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
+							ParamValue, OutVals[i]);
+		Ops[i+Ins.size()+4] = ParamValue;
+	  }
+  }
+  
   std::vector<SDValue> InParams;
 
   // Generate list of .param variables to hold the return value(s).

Added: llvm/trunk/test/CodeGen/PTX/printf.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/PTX/printf.ll?rev=144388&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/PTX/printf.ll (added)
+++ llvm/trunk/test/CodeGen/PTX/printf.ll Fri Nov 11 08:45:12 2011
@@ -0,0 +1,25 @@
+; RUN: llc < %s -march=ptx64 -mattr=+ptx20,+sm20 | FileCheck %s
+; printf calls are lowered to calls to the device-side vprintf.
+declare i32 @printf(i8*, ...)
+
+@str = private unnamed_addr constant [6 x i8] c"test\0A\00"
+
+define ptx_device void @t1_printf() {
+; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str;
+; CHECK: call.uni	(__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}});
+; CHECK: ret;
+  %1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([6 x i8]* @str, i64 0, i64 0))
+  ret void
+}
+
+@str2 = private unnamed_addr constant [11 x i8] c"test = %f\0A\00"
+
+define ptx_device void @t2_printf() {
+; CHECK: .local .align 8 .b8 __local{{[0-9]+}}[{{[0-9]+}}];
+; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str2;
+; CHECK: cvta.local.u64  %rd{{[0-9]+}}, __local{{[0-9]+}};
+; CHECK: call.uni	(__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}});
+; CHECK: ret;
+  %1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([11 x i8]* @str2, i64 0, i64 0), double 0x3FF3333340000000)
+  ret void
+}





More information about the llvm-commits mailing list