[PATCH] D30011: [NVPTX] Unify vectorization of load/stores of aggregate arguments and return values.

Justin Lebar via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 16 17:31:50 PST 2017


jlebar added inline comments.


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:190
+// we can process in a single ld/st operation (1 for scalar, 2 or 4
+// for vectors).
+static unsigned CanVectorizeAt(unsigned Idx, uint32_t AccessSize,
----------------
All right, now that I understand this, it's doing something deceptively simple.  Perhaps:

> Check whether we can merge loads/stores of some of the pieces of a flattened function parameter or return value into a single vector load/store.
>
> The flattened parameter is represented as a list of EVTs and offsets, and the whole structure is aligned to ParamAlignment.  This function determines whether we can load/store pieces of the parameter starting at index Idx using a single vectorized op of size AccessSize.  If so, it returns the number of param pieces covered by the vector op.  Otherwise, it returns 1.

Perhaps "CanVectorizeAt" is not a great name.  I think we want "Param" in the name.  Maybe `CanMergeParamLoadStoresStartingAt`?  Ideally we'd also pick a name that doesn't sound like this returns a boolean, but I don't have a suggestion.
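For concreteness, here is a standalone sketch of the function as the suggested comment describes it. It also folds in the changes proposed in the later comments: the exactness check sits directly after the division, and the loop runs over `Idx + 1 .. Idx + NumElts`. `ParamPiece` and its `TypeId` field are hypothetical stand-ins for the patch's parallel `EVT`/offset lists. The two alignment checks at the top are an assumption about the part of the function not quoted in this review.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one piece of a flattened parameter; the patch
// keeps this as an EVT in ValueVTs plus an entry in Offsets.
struct ParamPiece {
  int TypeId;         // stand-in for comparing EVTs for equality
  unsigned SizeBytes; // store size of the piece
  uint64_t Offset;    // byte offset within the parameter
};

// Returns how many pieces starting at Idx one vector ld/st of AccessSize
// bytes would cover, or 1 if they cannot be merged.
static unsigned CanMergeParamLoadStoresStartingAt(
    unsigned Idx, uint32_t AccessSize, const std::vector<ParamPiece> &Pieces,
    unsigned ParamAlignment) {
  assert(Idx < Pieces.size() && "Idx out of range");
  // Assumed: the parameter must be aligned at least as strictly as the access,
  if (ParamAlignment < AccessSize)
    return 1;
  // ...and the first piece must start at an AccessSize-aligned offset.
  if (Pieces[Idx].Offset % AccessSize != 0)
    return 1;

  unsigned EltSize = Pieces[Idx].SizeBytes;
  // Element is too large to vectorize.
  if (EltSize == 0 || EltSize >= AccessSize)
    return 1;
  unsigned NumElts = AccessSize / EltSize;
  // Checked right after the division: AccessSize must be a multiple of EltSize.
  if (AccessSize != EltSize * NumElts)
    return 1;
  // PTX ISA can only deal with 2- and 4-element vector ops.
  if (NumElts != 2 && NumElts != 4)
    return 1;
  // Not enough pieces left to fill the vector.
  if (Idx + NumElts > Pieces.size())
    return 1;

  for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
    // Types do not match.
    if (Pieces[j].TypeId != Pieces[Idx].TypeId)
      return 1;
    // Pieces are not contiguous.
    if (Pieces[j].Offset - Pieces[j - 1].Offset != EltSize)
      return 1;
  }
  return NumElts;
}
```

For a 16-byte-aligned `{i32, i32, i32, i32}`, this returns 4 for a 16-byte access at index 0, 2 for an 8-byte access at index 0 or 2, and 1 at index 1, whose offset is not 8-byte aligned.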


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:208
+  // Element is too large to vectorize.
+  if (EltSize >= AccessSize)
+    return 1;
----------------
Actually, I'm not sure why we bail here if the sizes are equal.  Don't we want to vectorize {i32, i32}, even if AccessSize is 4 bytes?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:217
+  // Can't vectorize if AccessBytes if not a multiple of EltSize.
+  if (AccessSize != EltSize * NumElts)
+    return 1;
----------------
Can we move this right below the NumElts computation?  I spent a while scratching my head trying to figure out how we knew that the division above is exact.  :)


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:220
+
+  // PTX ISA can only deal with 2 and 4 element vector ops.
+  if (NumElts != 4 && NumElts != 2)
----------------
Nit, "2- and 4-element vector ops"


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:224
+
+  for (unsigned j = 1; j < NumElts; ++j) {
+    // Types do not match.
----------------
Maybe it would make sense to write this as

  for (unsigned j = Idx + 1; j < Idx + NumElts; ++j)  { ... }


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:237
+
+enum PtxVectorInfo {
+  PTX_LDST_VECTORIZED = 0x0, // Middle elements of a vector.
----------------
Please add a brief comment.  I'm also not sure about the name, for the same reasons as CanVectorizeAt -- it's very generic.  When we change the enumeration name, we should probably also change PTX_LDST to something else.


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:248
+// alignment \p ParamAlignment, returns \p VectorInfo with each
+// element indicating whether its load/store can be vectorized.
+static void VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
----------------
Perhaps

> Computes whether and how we can vectorize the loads/stores of a flattened function parameter or return value.
>
> The flattened parameter is represented as the list of ValueVTs and Offsets, and is aligned to ParamAlignment bytes.  We return a vector of the same size as ValueVTs indicating how each piece should be loaded/stored (i.e. as a scalar, or as part of a vector load/store).


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:252
+                                 unsigned ParamAlignment,
+                                 SmallVectorImpl<PtxVectorInfo> &VectorInfo) {
+  // Set vector size to match ValueVTs and mark all elements as
----------------
Can we return the SmallVector instead of using an outparam?  We'll get RVO, so it's the same to the compiler.
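A minimal sketch of the return-by-value shape. It gives the enum a short comment and uses hypothetical, less generic flag names in place of `PTX_LDST_*`. A `MergeQuery` callback stands in for `CanVectorizeAt`, and `std::vector` replaces `SmallVector` so the block compiles on its own. The widest-first search order (16, 8, 4, then 2 bytes) is an assumption, not something quoted from the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <vector>

// Hypothetical names for the patch's PTX_LDST_* values: how each piece of a
// flattened parameter is loaded/stored. A piece that is both FIRST and LAST
// gets a plain scalar ld/st.
enum ParamVectorizationFlags {
  PVF_INNER = 0x0,                  // Middle element of a vector op.
  PVF_FIRST = 0x1,                  // First element of a vector op.
  PVF_LAST = 0x2,                   // Last element of a vector op.
  PVF_SCALAR = PVF_FIRST | PVF_LAST // Stand-alone scalar op.
};

// Stand-in for CanVectorizeAt: returns how many pieces (1, 2 or 4) a single
// op of AccessSize bytes starting at piece Idx would cover.
using MergeQuery = std::function<unsigned(unsigned Idx, uint32_t AccessSize)>;

// Returns one flag per piece instead of filling an out-parameter.
static std::vector<ParamVectorizationFlags>
VectorizeParamPieces(unsigned NumPieces, const MergeQuery &CanMerge) {
  // Every piece starts out as a scalar.
  std::vector<ParamVectorizationFlags> Info(NumPieces, PVF_SCALAR);
  for (unsigned I = 0; I < NumPieces; ++I) {
    // Try the widest access first.
    for (uint32_t AccessSize : {16u, 8u, 4u, 2u}) {
      unsigned NumElts = CanMerge(I, AccessSize);
      if (NumElts == 1)
        continue; // Not mergeable at this width; try a narrower one.
      assert((NumElts == 2 || NumElts == 4) && I + NumElts <= NumPieces);
      Info[I] = PVF_FIRST;
      for (unsigned J = I + 1; J + 1 < I + NumElts; ++J)
        Info[J] = PVF_INNER;
      Info[I + NumElts - 1] = PVF_LAST;
      I += NumElts - 1; // Skip the pieces this vector op covers.
      break;
    }
  }
  return Info;
}
```

Because `Info` is a named local, the compiler can elide the copy on return (NRVO) and otherwise moves it, so callers pay nothing extra compared with the out-parameter.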


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:1417
     if (!Outs[OIdx].Flags.isByVal()) {
-      if (Ty->isAggregateType()) {
-        // aggregate
-        SmallVector<EVT, 16> vtparts;
-        SmallVector<uint64_t, 16> Offsets;
-        ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts, &Offsets,
-                           0);
-
-        unsigned align =
-            getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL);
-        // declare .param .align <align> .b8 .param<n>[<size>];
-        unsigned sz = DL.getTypeAllocSize(Ty);
-        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-        SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, dl,
-                                                             MVT::i32),
-                                      DAG.getConstant(paramCount, dl, MVT::i32),
-                                      DAG.getConstant(sz, dl, MVT::i32),
-                                      InFlag };
-        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
-                            DeclareParamOps);
-        InFlag = Chain.getValue(1);
-        for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
-          EVT elemtype = vtparts[j];
-          unsigned ArgAlign = GreatestCommonDivisor64(align, Offsets[j]);
-          if (elemtype.isInteger() && (sz < 8))
-            sz = 8;
-          SDValue StVal = OutVals[OIdx];
-          if (elemtype.getSizeInBits() < 16) {
-            StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
-          }
-          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-          SDValue CopyParamOps[] = { Chain,
-                                     DAG.getConstant(paramCount, dl, MVT::i32),
-                                     DAG.getConstant(Offsets[j], dl, MVT::i32),
-                                     StVal, InFlag };
-          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
-                                          CopyParamVTs, CopyParamOps,
-                                          elemtype, MachinePointerInfo(),
-                                          ArgAlign);
-          InFlag = Chain.getValue(1);
-          ++OIdx;
-        }
-        if (vtparts.size() > 0)
-          --OIdx;
-        ++paramCount;
-        continue;
-      }
-      if (Ty->isVectorTy()) {
-        EVT ObjectVT = getValueType(DL, Ty);
-        unsigned align =
-            getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL);
-        // declare .param .align <align> .b8 .param<n>[<size>];
-        unsigned sz = DL.getTypeAllocSize(Ty);
-        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-        SDValue DeclareParamOps[] = { Chain,
-                                      DAG.getConstant(align, dl, MVT::i32),
-                                      DAG.getConstant(paramCount, dl, MVT::i32),
-                                      DAG.getConstant(sz, dl, MVT::i32),
-                                      InFlag };
-        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
-                            DeclareParamOps);
-        InFlag = Chain.getValue(1);
-        unsigned NumElts = ObjectVT.getVectorNumElements();
-        EVT EltVT = ObjectVT.getVectorElementType();
-        EVT MemVT = EltVT;
-        bool NeedExtend = false;
-        if (EltVT.getSizeInBits() < 16) {
-          NeedExtend = true;
-          EltVT = MVT::i16;
-        }
-
-        // V1 store
-        if (NumElts == 1) {
-          SDValue Elt = OutVals[OIdx++];
-          if (NeedExtend)
-            Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
-
-          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-          SDValue CopyParamOps[] = { Chain,
-                                     DAG.getConstant(paramCount, dl, MVT::i32),
-                                     DAG.getConstant(0, dl, MVT::i32), Elt,
-                                     InFlag };
-          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
-                                          CopyParamVTs, CopyParamOps,
-                                          MemVT, MachinePointerInfo());
-          InFlag = Chain.getValue(1);
-        } else if (NumElts == 2) {
-          SDValue Elt0 = OutVals[OIdx++];
-          SDValue Elt1 = OutVals[OIdx++];
-          if (NeedExtend) {
-            Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
-            Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
-          }
-
-          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-          SDValue CopyParamOps[] = { Chain,
-                                     DAG.getConstant(paramCount, dl, MVT::i32),
-                                     DAG.getConstant(0, dl, MVT::i32), Elt0,
-                                     Elt1, InFlag };
-          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
-                                          CopyParamVTs, CopyParamOps,
-                                          MemVT, MachinePointerInfo());
-          InFlag = Chain.getValue(1);
-        } else {
-          unsigned curOffset = 0;
-          // V4 stores
-          // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
-          // the
-          // vector will be expanded to a power of 2 elements, so we know we can
-          // always round up to the next multiple of 4 when creating the vector
-          // stores.
-          // e.g.  4 elem => 1 st.v4
-          //       6 elem => 2 st.v4
-          //       8 elem => 2 st.v4
-          //      11 elem => 3 st.v4
-          unsigned VecSize = 4;
-          if (EltVT.getSizeInBits() == 64)
-            VecSize = 2;
-
-          // This is potentially only part of a vector, so assume all elements
-          // are packed together.
-          unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
-
-          for (unsigned i = 0; i < NumElts; i += VecSize) {
-            // Get values
-            SDValue StoreVal;
-            SmallVector<SDValue, 8> Ops;
-            Ops.push_back(Chain);
-            Ops.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
-            Ops.push_back(DAG.getConstant(curOffset, dl, MVT::i32));
-
-            unsigned Opc = NVPTXISD::StoreParamV2;
-
-            StoreVal = OutVals[OIdx++];
-            if (NeedExtend)
-              StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
-            Ops.push_back(StoreVal);
-
-            if (i + 1 < NumElts) {
-              StoreVal = OutVals[OIdx++];
-              if (NeedExtend)
-                StoreVal =
-                    DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
-            } else {
-              StoreVal = DAG.getUNDEF(EltVT);
-            }
-            Ops.push_back(StoreVal);
-
-            if (VecSize == 4) {
-              Opc = NVPTXISD::StoreParamV4;
-              if (i + 2 < NumElts) {
-                StoreVal = OutVals[OIdx++];
-                if (NeedExtend)
-                  StoreVal =
-                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
-              } else {
-                StoreVal = DAG.getUNDEF(EltVT);
-              }
-              Ops.push_back(StoreVal);
-
-              if (i + 3 < NumElts) {
-                StoreVal = OutVals[OIdx++];
-                if (NeedExtend)
-                  StoreVal =
-                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
-              } else {
-                StoreVal = DAG.getUNDEF(EltVT);
-              }
-              Ops.push_back(StoreVal);
-            }
-
-            Ops.push_back(InFlag);
-
-            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-            Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, Ops,
-                                            MemVT, MachinePointerInfo());
-            InFlag = Chain.getValue(1);
-            curOffset += PerStoreOffset;
-          }
-        }
-        ++paramCount;
-        --OIdx;
-        continue;
-      }
-      // Plain scalar
-      // for ABI,    declare .param .b<size> .param<n>;
-      unsigned sz = VT.getSizeInBits();
-      bool needExtend = false;
-      if (VT.isInteger()) {
-        if (sz < 16)
-          needExtend = true;
-        if (sz < 32)
-          sz = 32;
-      } else if (VT.isFloatingPoint() && sz < 32)
-        // PTX ABI requires all scalar parameters to be at least 32
-        // bits in size.  fp16 normally uses .b16 as its storage type
-        // in PTX, so its size must be adjusted here, too.
-        sz = 32;
+      // aggregate
+      SmallVector<EVT, 16> VTs;
----------------
Does this comment still hold?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:1456
+      // only to scalar parameters and not to aggregate values.
+      bool ExtendIntegerRetVal =
+          Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
----------------
Is this the right name?  We're dealing with function params, not a return value, I think?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:1461
+      VectorizePTXValueVTs(VTs, Offsets, ArgAlign, VectorInfo);
+      SmallVector<SDValue, 6> LdStOps;
+      for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
----------------
`StoreOperands` or `StNodeOperands`?  "Ops" is a bad abbreviation for "operands".  :)  And it looks like this is specifically for stores?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:1461
+      VectorizePTXValueVTs(VTs, Offsets, ArgAlign, VectorInfo);
+      SmallVector<SDValue, 6> LdStOps;
+      for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
----------------
jlebar wrote:
> `StoreOperands` or `StNodeOperands`?  "Ops" is a bad abbreviation for "operands".  :)  And it looks like this is specifically for stores?
Can we assert somewhere that this vector is empty when we're done with the loop?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:1463
+      for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
+        // New load/store.
+        if (VectorInfo[j] & PTX_LDST_BEGIN) {
----------------
"New store."?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:1480
+        } else if (EltVT.getSizeInBits() < 16) {
+          // Use 16-bit registers for small load-stores as it's the
+          // smallest general purpose register size supported by NVPTX.
----------------
s/load-stores/stores/?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:1507
+
+          // Adjust type of load/store op if we've extended the scalar
+          // return value.
----------------
s#load/store#store#?


================
Comment at: lib/Target/NVPTX/NVPTXISelLowering.cpp:2400
 
-        StoreVal = OutVals[i];
-        if (NeedExtend)
-          StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
-        Ops.push_back(StoreVal);
-
-        if (i + 1 < NumElts) {
-          StoreVal = OutVals[i + 1];
-          if (NeedExtend)
-            StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
-        } else {
-          StoreVal = DAG.getUNDEF(ExtendedVT);
-        }
-        Ops.push_back(StoreVal);
-
-        if (VecSize == 4) {
-          Opc = NVPTXISD::StoreRetvalV4;
-          if (i + 2 < NumElts) {
-            StoreVal = OutVals[i + 2];
-            if (NeedExtend)
-              StoreVal =
-                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
-          } else {
-            StoreVal = DAG.getUNDEF(ExtendedVT);
-          }
-          Ops.push_back(StoreVal);
-
-          if (i + 3 < NumElts) {
-            StoreVal = OutVals[i + 3];
-            if (NeedExtend)
-              StoreVal =
-                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
-          } else {
-            StoreVal = DAG.getUNDEF(ExtendedVT);
-          }
-          Ops.push_back(StoreVal);
-        }
+  SmallVector<SDValue, 6> LdStOps;
+  for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
----------------
This looks oddly familiar from the parameter handling above.  We really can't factor it out without making this into more of a hairball?  If not, same comments apply down here.
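One way the common part might be factored out, sketched on the assumption that the grouping loop could be shared: one visible difference between the two paths is which family of store nodes gets built (`StoreParam*` for call arguments, `StoreRetval*` for return values, as in the code above), so that choice can be passed in. The parameter stores also carry a parameter index operand, which a real refactoring would need to handle as well. `FakeNVPTXISD` and `StoreOpcodeFamily` are hypothetical stand-ins, not the real `NVPTXISD` values.

```cpp
#include <cassert>

// Hypothetical stand-ins for the NVPTXISD node types; only their
// distinctness matters here.
namespace FakeNVPTXISD {
enum NodeType {
  StoreParam, StoreParamV2, StoreParamV4,
  StoreRetval, StoreRetvalV2, StoreRetvalV4
};
} // namespace FakeNVPTXISD

// The per-path choice a shared store-emission helper would take as input.
struct StoreOpcodeFamily {
  FakeNVPTXISD::NodeType Scalar, V2, V4;
};

constexpr StoreOpcodeFamily ParamStores{FakeNVPTXISD::StoreParam,
                                        FakeNVPTXISD::StoreParamV2,
                                        FakeNVPTXISD::StoreParamV4};
constexpr StoreOpcodeFamily RetvalStores{FakeNVPTXISD::StoreRetval,
                                         FakeNVPTXISD::StoreRetvalV2,
                                         FakeNVPTXISD::StoreRetvalV4};

// Picks the node type for a store covering NumElts pieces.
static FakeNVPTXISD::NodeType SelectStoreOpcode(const StoreOpcodeFamily &F,
                                                unsigned NumElts) {
  switch (NumElts) {
  case 1:
    return F.Scalar;
  case 2:
    return F.V2;
  case 4:
    return F.V4;
  }
  assert(false && "PTX only has scalar, 2- and 4-element stores");
  return F.Scalar;
}
```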


https://reviews.llvm.org/D30011
