[llvm] 5bf86d9 - [NVPTX] Remove code duplication in LowerCall

Daniil Kovalev via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 25 02:43:09 PDT 2022


Author: Daniil Kovalev
Date: 2022-03-25T12:36:20+03:00
New Revision: 5bf86d9e88fa841f5f50f4b8e3b337191691a45d

URL: https://github.com/llvm/llvm-project/commit/5bf86d9e88fa841f5f50f4b8e3b337191691a45d
DIFF: https://github.com/llvm/llvm-project/commit/5bf86d9e88fa841f5f50f4b8e3b337191691a45d.diff

LOG: [NVPTX] Remove code duplication in LowerCall

In D120129 we enhanced vectorization options of byval parameters. This patch
removes code duplication when handling byval and non-byval cases.

Differential Revision: https://reviews.llvm.org/D122381

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 382e83dbb4cb9..11fc25722fcd6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1441,11 +1441,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     return Chain;
 
   unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
-  SDValue tempChain = Chain;
+  SDValue TempChain = Chain;
   Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
   SDValue InFlag = Chain.getValue(1);
 
-  unsigned paramCount = 0;
+  unsigned ParamCount = 0;
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
   //   * if there is an aggregate argument with multiple fields (each field
@@ -1461,185 +1461,115 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
     EVT VT = Outs[OIdx].VT;
     Type *Ty = Args[i].Ty;
+    bool IsByVal = Outs[OIdx].Flags.isByVal();
 
-    if (!Outs[OIdx].Flags.isByVal()) {
-      SmallVector<EVT, 16> VTs;
-      SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets);
-      Align ArgAlign = getArgumentAlignment(Callee, CB, Ty, paramCount + 1, DL);
-      unsigned AllocSize = DL.getTypeAllocSize(Ty);
-      SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-      bool NeedAlign; // Does argument declaration specify alignment?
-      if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
-        // declare .param .align <align> .b8 .param<n>[<size>];
-        SDValue DeclareParamOps[] = {
-            Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
-            DAG.getConstant(paramCount, dl, MVT::i32),
-            DAG.getConstant(AllocSize, dl, MVT::i32), InFlag};
-        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
-                            DeclareParamOps);
-        NeedAlign = true;
-      } else {
-        // declare .param .b<size> .param<n>;
-        if ((VT.isInteger() || VT.isFloatingPoint()) && AllocSize < 4) {
-          // PTX ABI requires integral types to be at least 32 bits in
-          // size. FP16 is loaded/stored using i16, so it's handled
-          // here as well.
-          AllocSize = 4;
-        }
-        SDValue DeclareScalarParamOps[] = {
-            Chain, DAG.getConstant(paramCount, dl, MVT::i32),
-            DAG.getConstant(AllocSize * 8, dl, MVT::i32),
-            DAG.getConstant(0, dl, MVT::i32), InFlag};
-        Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
-                            DeclareScalarParamOps);
-        NeedAlign = false;
-      }
-      InFlag = Chain.getValue(1);
-
-      // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
-      // than 32-bits are sign extended or zero extended, depending on
-      // whether they are signed or unsigned types. This case applies
-      // only to scalar parameters and not to aggregate values.
-      bool ExtendIntegerParam =
-          Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
-
-      auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
-      SmallVector<SDValue, 6> StoreOperands;
-      for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
-        // New store.
-        if (VectorInfo[j] & PVF_FIRST) {
-          assert(StoreOperands.empty() && "Unfinished preceding store.");
-          StoreOperands.push_back(Chain);
-          StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
-          StoreOperands.push_back(DAG.getConstant(Offsets[j], dl, MVT::i32));
-        }
-
-        EVT EltVT = VTs[j];
-        SDValue StVal = OutVals[OIdx];
-        if (ExtendIntegerParam) {
-          assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
-          // zext/sext to i32
-          StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
-                                                        : ISD::ZERO_EXTEND,
-                              dl, MVT::i32, StVal);
-        } else if (EltVT.getSizeInBits() < 16) {
-          // Use 16-bit registers for small stores as it's the
-          // smallest general purpose register size supported by NVPTX.
-          StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
-        }
-
-        // Record the value to store.
-        StoreOperands.push_back(StVal);
-
-        if (VectorInfo[j] & PVF_LAST) {
-          unsigned NumElts = StoreOperands.size() - 3;
-          NVPTXISD::NodeType Op;
-          switch (NumElts) {
-          case 1:
-            Op = NVPTXISD::StoreParam;
-            break;
-          case 2:
-            Op = NVPTXISD::StoreParamV2;
-            break;
-          case 4:
-            Op = NVPTXISD::StoreParamV4;
-            break;
-          default:
-            llvm_unreachable("Invalid vector info.");
-          }
-
-          StoreOperands.push_back(InFlag);
-
-          // Adjust type of the store op if we've extended the scalar
-          // return value.
-          EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : VTs[j];
-          MaybeAlign EltAlign;
-          if (NeedAlign)
-            EltAlign = commonAlignment(ArgAlign, Offsets[j]);
-
-          Chain = DAG.getMemIntrinsicNode(
-              Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
-              TheStoreType, MachinePointerInfo(), EltAlign,
-              MachineMemOperand::MOStore);
-          InFlag = Chain.getValue(1);
-
-          // Cleanup.
-          StoreOperands.clear();
-        }
-        ++OIdx;
-      }
-      assert(StoreOperands.empty() && "Unfinished parameter store.");
-      if (VTs.size() > 0)
-        --OIdx;
-      ++paramCount;
-      continue;
-    }
-
-    // ByVal arguments
-    // TODO: remove code duplication when handling byval and non-byval cases.
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    Type *ETy = Args[i].IndirectType;
-    assert(ETy && "byval arg must have indirect type");
-    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, 0);
 
-    // declare .param .align <align> .b8 .param<n>[<size>];
-    unsigned sz = Outs[OIdx].Flags.getByValSize();
-    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+    assert((!IsByVal || Args[i].IndirectType) &&
+           "byval arg must have indirect type");
+    Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
+    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets);
+
+    Align ArgAlign;
+    if (IsByVal) {
+      // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
+      // so we don't need to worry whether it's naturally aligned or not.
+      // See TargetLowering::LowerCallTo().
+      ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+
+      // Try to increase alignment to enhance vectorization options.
+      ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(
+                                        CB->getCalledFunction(), ETy, DL));
+
+      // Enforce minimum alignment of 4 to work around ptxas miscompile
+      // for sm_50+. See corresponding alignment adjustment in
+      // emitFunctionParamList() for details.
+      ArgAlign = std::max(ArgAlign, Align(4));
+    } else {
+      ArgAlign = getArgumentAlignment(Callee, CB, Ty, ParamCount + 1, DL);
+    }
 
-    // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
-    // so we don't need to worry about natural alignment or not.
-    // See TargetLowering::LowerCallTo().
-    Align ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+    unsigned TypeSize =
+        (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
+    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
 
-    // Try to increase alignment to enhance vectorization options.
-    const Function *F = CB->getCalledFunction();
-    Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
-    ArgAlign = std::max(ArgAlign, AlignCandidate);
-
-    // Enforce minumum alignment of 4 to work around ptxas miscompile
-    // for sm_50+. See corresponding alignment adjustment in
-    // emitFunctionParamList() for details.
-    if (ArgAlign < Align(4))
-      ArgAlign = Align(4);
-    SDValue DeclareParamOps[] = {
-        Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
-        DAG.getConstant(paramCount, dl, MVT::i32),
-        DAG.getConstant(sz, dl, MVT::i32), InFlag};
-    Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
-                        DeclareParamOps);
+    bool NeedAlign; // Does argument declaration specify alignment?
+    if (IsByVal ||
+        (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128))) {
+      // declare .param .align <align> .b8 .param<n>[<size>];
+      SDValue DeclareParamOps[] = {
+          Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
+          DAG.getConstant(ParamCount, dl, MVT::i32),
+          DAG.getConstant(TypeSize, dl, MVT::i32), InFlag};
+      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
+                          DeclareParamOps);
+      NeedAlign = true;
+    } else {
+      // declare .param .b<size> .param<n>;
+      if ((VT.isInteger() || VT.isFloatingPoint()) && TypeSize < 4) {
+        // PTX ABI requires integral types to be at least 32 bits in
+        // size. FP16 is loaded/stored using i16, so it's handled
+        // here as well.
+        TypeSize = 4;
+      }
+      SDValue DeclareScalarParamOps[] = {
+          Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
+          DAG.getConstant(TypeSize * 8, dl, MVT::i32),
+          DAG.getConstant(0, dl, MVT::i32), InFlag};
+      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
+                          DeclareScalarParamOps);
+      NeedAlign = false;
+    }
     InFlag = Chain.getValue(1);
 
+    // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+    // than 32-bits are sign extended or zero extended, depending on
+    // whether they are signed or unsigned types. This case applies
+    // only to scalar parameters and not to aggregate values.
+    bool ExtendIntegerParam =
+        Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
+
     auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
     SmallVector<SDValue, 6> StoreOperands;
     for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
-      EVT elemtype = VTs[j];
-      int curOffset = Offsets[j];
-      Align PartAlign = commonAlignment(ArgAlign, curOffset);
+      EVT EltVT = VTs[j];
+      int CurOffset = Offsets[j];
+      MaybeAlign PartAlign;
+      if (NeedAlign)
+        PartAlign = commonAlignment(ArgAlign, CurOffset);
 
       // New store.
       if (VectorInfo[j] & PVF_FIRST) {
         assert(StoreOperands.empty() && "Unfinished preceding store.");
         StoreOperands.push_back(Chain);
-        StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
-        StoreOperands.push_back(DAG.getConstant(curOffset, dl, MVT::i32));
+        StoreOperands.push_back(DAG.getConstant(ParamCount, dl, MVT::i32));
+        StoreOperands.push_back(DAG.getConstant(CurOffset, dl, MVT::i32));
       }
 
-      auto PtrVT = getPointerTy(DL);
-      SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, OutVals[OIdx],
-                                    DAG.getConstant(curOffset, dl, PtrVT));
-      SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
-                                   MachinePointerInfo(), PartAlign);
+      SDValue StVal = OutVals[OIdx];
+      if (IsByVal) {
+        auto PtrVT = getPointerTy(DL);
+        SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
+                                      DAG.getConstant(CurOffset, dl, PtrVT));
+        StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
+                            PartAlign);
+      } else if (ExtendIntegerParam) {
+        assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
+        // zext/sext to i32
+        StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
+                                                      : ISD::ZERO_EXTEND,
+                            dl, MVT::i32, StVal);
+      }
 
-      if (elemtype.getSizeInBits() < 16) {
+      if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
         // Use 16-bit registers for small stores as it's the
         // smallest general purpose register size supported by NVPTX.
-        theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
+        StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
       }
 
       // Record the value to store.
-      StoreOperands.push_back(theVal);
+      StoreOperands.push_back(StVal);
 
       if (VectorInfo[j] & PVF_LAST) {
         unsigned NumElts = StoreOperands.size() - 3;
@@ -1660,18 +1590,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
         StoreOperands.push_back(InFlag);
 
+        // Adjust type of the store op if we've extended the scalar
+        // parameter value.
+        EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+
         Chain = DAG.getMemIntrinsicNode(
             Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
-            elemtype, MachinePointerInfo(), PartAlign,
+            TheStoreType, MachinePointerInfo(), PartAlign,
             MachineMemOperand::MOStore);
         InFlag = Chain.getValue(1);
 
         // Cleanup.
         StoreOperands.clear();
       }
+      if (!IsByVal)
+        ++OIdx;
     }
     assert(StoreOperands.empty() && "Unfinished parameter store.");
-    ++paramCount;
+    if (!IsByVal && VTs.size() > 0)
+      --OIdx;
+    ++ParamCount;
   }
 
   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
@@ -1778,7 +1716,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                       CallArgBeginOps);
   InFlag = Chain.getValue(1);
 
-  for (unsigned i = 0, e = paramCount; i != e; ++i) {
+  for (unsigned i = 0, e = ParamCount; i != e; ++i) {
     unsigned opcode;
     if (i == (e - 1))
       opcode = NVPTXISD::LastCallArg;


        


More information about the llvm-commits mailing list