[llvm] r177847 - [NVPTX] Fix handling of vector arguments

Justin Holewinski jholewinski at nvidia.com
Sun Mar 24 14:17:48 PDT 2013


Author: jholewinski
Date: Sun Mar 24 16:17:47 2013
New Revision: 177847

URL: http://llvm.org/viewvc/llvm-project?rev=177847&view=rev
Log:
[NVPTX] Fix handling of vector arguments

Added:
    llvm/trunk/test/CodeGen/NVPTX/vector-args.ll
Modified:
    llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp?rev=177847&r1=177846&r2=177847&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXAsmPrinter.cpp Sun Mar 24 16:17:47 2013
@@ -1481,7 +1481,7 @@ void NVPTXAsmPrinter::emitFunctionParamL
   O << "(\n";
 
   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
-    const Type *Ty = I->getType();
+    Type *Ty = I->getType();
 
     if (!first)
       O << ",\n";
@@ -1504,6 +1504,22 @@ void NVPTXAsmPrinter::emitFunctionParamL
     }
 
     if (PAL.hasAttribute(paramIndex+1, Attribute::ByVal) == false) {
+      if (Ty->isVectorTy()) {
+        // Just print .param .b8 .align <a> .param[size];
+        // <a> = PAL.getparamalignment
+        // size = typeallocsize of element type
+        unsigned align = PAL.getParamAlignment(paramIndex+1);
+        if (align == 0)
+          align = TD->getABITypeAlignment(Ty);
+
+        unsigned sz = TD->getTypeAllocSize(Ty);
+        O << "\t.param .align " << align
+          << " .b8 ";
+        printParamName(I, paramIndex, O);
+        O << "[" << sz << "]";
+
+        continue;
+      }
       // Just a scalar
       const PointerType *PTy = dyn_cast<PointerType>(Ty);
       if (isKernelFunc) {

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=177847&r1=177846&r2=177847&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Sun Mar 24 16:17:47 2013
@@ -1058,15 +1058,15 @@ NVPTXTargetLowering::LowerFormalArgument
     theArgs.push_back(I);
     argTypes.push_back(I->getType());
   }
-  assert(argTypes.size() == Ins.size() &&
-         "Ins types and function types did not match");
+  //assert(argTypes.size() == Ins.size() &&
+  //       "Ins types and function types did not match");
 
   int idx = 0;
-  for (unsigned i=0, e=Ins.size(); i!=e; ++i, ++idx) {
+  for (unsigned i=0, e=argTypes.size(); i!=e; ++i, ++idx) {
     Type *Ty = argTypes[i];
     EVT ObjectVT = getValueType(Ty);
-    assert(ObjectVT == Ins[i].VT &&
-           "Ins type did not match function type");
+    //assert(ObjectVT == Ins[i].VT &&
+    //       "Ins type did not match function type");
 
     // If the kernel argument is image*_t or sampler_t, convert it to
     // a i32 constant holding the parameter position. This can later
@@ -1081,7 +1081,15 @@ NVPTXTargetLowering::LowerFormalArgument
 
     if (theArgs[i]->use_empty()) {
       // argument is dead
-      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
+      if (ObjectVT.isVector()) {
+        EVT EltVT = ObjectVT.getVectorElementType();
+        unsigned NumElts = ObjectVT.getVectorNumElements();
+        for (unsigned vi = 0; vi < NumElts; ++vi) {
+          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT));
+        }
+      } else {
+        InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
+      }
       continue;
     }
 
@@ -1090,6 +1098,31 @@ NVPTXTargetLowering::LowerFormalArgument
     // appear in the same order as their order of appearance
     // in the original function. "idx+1" holds that order.
     if (PAL.hasAttribute(i+1, Attribute::ByVal) == false) {
+      if (ObjectVT.isVector()) {
+        unsigned NumElts = ObjectVT.getVectorNumElements();
+        EVT EltVT = ObjectVT.getVectorElementType();
+        unsigned Offset = 0;
+        for (unsigned vi = 0; vi < NumElts; ++vi) {
+          SDValue A = getParamSymbol(DAG, idx, getPointerTy());
+          SDValue B = DAG.getIntPtrConstant(Offset);
+          SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
+                                     //getParamSymbol(DAG, idx, EltVT),
+                                     //DAG.getConstant(Offset, getPointerTy()));
+                                     A, B);
+          Value *SrcValue = Constant::getNullValue(PointerType::get(
+                                            EltVT.getTypeForEVT(F->getContext()),
+                                            llvm::ADDRESS_SPACE_PARAM));
+          SDValue Ld = DAG.getLoad(EltVT, dl, Root, Addr,
+                                   MachinePointerInfo(SrcValue),
+                                   false, false, false,
+                                   TD->getABITypeAlignment(EltVT.getTypeForEVT(
+                                     F->getContext())));
+          Offset += EltVT.getStoreSizeInBits()/8;
+          InVals.push_back(Ld);
+        }
+        continue;
+      }
+
       // A plain scalar.
       if (isABI || isKernel) {
         // If ABI, load from the param symbol

Added: llvm/trunk/test/CodeGen/NVPTX/vector-args.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/vector-args.ll?rev=177847&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/vector-args.ll (added)
+++ llvm/trunk/test/CodeGen/NVPTX/vector-args.ll Sun Mar 24 16:17:47 2013
@@ -0,0 +1,27 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+; Vector parameters are declared as aligned byte arrays (.param .align A .b8 name[N],
+; N = alloc size of the vector) and read back in the body with scalar ld.param.f32 loads.
+define float @foo(<2 x float> %a) {
+; CHECK: .func (.param .b32 func_retval0) foo
+; CHECK: .param .align 8 .b8 foo_param_0[8]
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+  %t1 = fmul <2 x float> %a, %a
+  %t2 = extractelement <2 x float> %t1, i32 0
+  %t3 = extractelement <2 x float> %t1, i32 1
+  %t4 = fadd float %t2, %t3
+  ret float %t4
+}
+
+; <4 x float>: 16-byte alignment and a 16-byte array; only elements 0 and 1 are used, so two loads are checked.
+define float @bar(<4 x float> %a) {
+; CHECK: .func (.param .b32 func_retval0) bar
+; CHECK: .param .align 16 .b8 bar_param_0[16]
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+  %t1 = fmul <4 x float> %a, %a
+  %t2 = extractelement <4 x float> %t1, i32 0
+  %t3 = extractelement <4 x float> %t1, i32 1
+  %t4 = fadd float %t2, %t3
+  ret float %t4
+}





More information about the llvm-commits mailing list