[llvm] [NVPTX] Add TMA bulk tensor copy intrinsics (PR #96083)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 11:49:50 PST 2024


================
@@ -4150,3 +4154,235 @@ NVPTX::Scope NVPTXScopes::operator[](SyncScope::ID ID) const {
 }
 
 /// Returns true when the Scopes container holds no entries.
 bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
+
+// Opcode-selection macros for the CP_ASYNC_BULK_TENSOR_* machine instructions.
+// They are not hygienic: each one reads boolean locals (IsShared32,
+// IsCacheHint, IsMultiCast) from the scope in which it is expanded.
+//
+// CP_ASYNC_BULK_TENSOR_OPCODE pastes together the enumerator
+//   NVPTX::CP_ASYNC_BULK_TENSOR_<dir>_<dim>[_SHARED32]_<mode><suffix>
+// and picks the _SHARED32 flavour when IsShared32 is true.
+#define CP_ASYNC_BULK_TENSOR_OPCODE(dir, dim, mode, suffix)                    \
+  (IsShared32                                                                  \
+       ? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix   \
+       : NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix)
+
+// S2G selector: appends the _CH suffix when IsCacheHint is true.
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode)                         \
+  (IsCacheHint ? (CP_ASYNC_BULK_TENSOR_OPCODE(S2G, dim, mode, _CH))            \
+               : (CP_ASYNC_BULK_TENSOR_OPCODE(S2G, dim, mode, )))
+
+// G2S selector: the suffix is _MC_CH, _MC, _CH or empty depending on the
+// (IsMultiCast, IsCacheHint) pair.
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode)                         \
+  (IsMultiCast                                                                 \
+       ? (IsCacheHint ? CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, _MC_CH)    \
+                      : CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, _MC))      \
+       : (IsCacheHint ? CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, _CH)       \
+                      : CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, )))
+
+/// Returns the CP_ASYNC_BULK_TENSOR_S2G_* machine opcode for a tensor with
+/// \p Dim dimensions.
+///
+/// \param Dim         Tensor rank: 1-5 in tile mode, 3-5 in im2col mode.
+/// \param IsShared32  Selects the _SHARED32 opcode flavour.
+/// \param IsCacheHint Selects the _CH opcode flavour.
+/// \param IsIm2Col    Selects the IM2COL opcodes instead of the TILE ones.
+///
+/// Any other \p Dim is a programming error and hits llvm_unreachable.
+static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
+                                              bool IsCacheHint, bool IsIm2Col) {
+  // Tile mode is handled first; every path out of the switch returns.
+  if (!IsIm2Col) {
+    switch (Dim) {
+    case 1:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE);
+    case 2:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE);
+    case 3:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE);
+    case 4:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE);
+    case 5:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE);
+    default:
+      llvm_unreachable(
+          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
+    }
+  }
+
+  // Im2col mode only has 3D, 4D and 5D opcodes.
+  switch (Dim) {
+  case 3:
+    return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL);
+  case 4:
+    return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL);
+  case 5:
+    return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL);
+  default:
+    llvm_unreachable("Invalid Dimension in im2col mode for "
+                     "GetCpAsyncBulkTensorS2GOpcode.");
+  }
+}
+
+/// Returns the CP_ASYNC_BULK_TENSOR_G2S_* machine opcode for a tensor with
+/// \p Dim dimensions.
+///
+/// \param Dim         Tensor rank: 1-5 in tile mode, 3-5 in im2col mode.
+/// \param IsShared32  Selects the _SHARED32 opcode flavour.
+/// \param IsMultiCast Selects the _MC opcode flavour.
+/// \param IsCacheHint Selects the _CH opcode flavour.
+/// \param IsIm2Col    Selects the IM2COL opcodes instead of the TILE ones.
+///
+/// Any other \p Dim is a programming error and hits llvm_unreachable.
+static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
+                                              bool IsMultiCast,
+                                              bool IsCacheHint, bool IsIm2Col) {
+  // Tile mode is handled first; every path out of the switch returns.
+  if (!IsIm2Col) {
+    switch (Dim) {
+    case 1:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE);
+    case 2:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE);
+    case 3:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE);
+    case 4:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE);
+    case 5:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE);
+    default:
+      llvm_unreachable(
+          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode.");
+    }
+  }
+
+  // Im2col mode only has 3D, 4D and 5D opcodes.
+  switch (Dim) {
+  case 3:
+    return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL);
+  case 4:
+    return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL);
+  case 5:
+    return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL);
+  default:
+    llvm_unreachable("Invalid Dimension in im2col mode for "
+                     "GetCpAsyncBulkTensorG2SOpcode.");
+  }
+}
+
+/// Returns the tensor rank (3, 4 or 5) that corresponds to one of the
+/// nvvm_cp_async_bulk_tensor_g2s_im2col_{3d,4d,5d} intrinsic IDs.
+/// Any other ID is a programming error and hits llvm_unreachable.
+static size_t GetCpAsyncBulkTensorG2SIm2ColNumDims(unsigned IID) {
+  switch (IID) {
+  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
+    return 3;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
+    return 4;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
+    return 5;
+  default:
+    llvm_unreachable("Invalid intrinsic ID for "
+                     "SelectCpAsyncBulkTensorG2SIm2Col.");
+  }
+}
+
+/// Replaces the G2S im2col intrinsic node \p N with the matching
+/// CP_ASYNC_BULK_TENSOR_G2S_*_IM2COL* machine node.
+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SIm2Col(SDNode *N) {
+  // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
+  // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2},
+  //  multicast, cache_hint,
+  //  multicast_flag, cache_hint_flag}
+  //
+  // The rank cannot be derived from the operand count alone here (the number
+  // of offsets depends on it too), so it is looked up from the intrinsic ID.
+  size_t NumDims =
+      GetCpAsyncBulkTensorG2SIm2ColNumDims(N->getConstantOperandVal(1));
+  size_t NumOffsets = NumDims - 2; // Offsets is always 'NumDims - 2'
+  size_t NumOps = N->getNumOperands();
+  bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
+  bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
+
+  SDLoc DL(N);
+  SmallVector<SDValue, 8> Ops{
+      N->getOperand(2), // Dst pointer in smem
+      N->getOperand(3), // Mbarrier pointer in smem
+      N->getOperand(4), // Src pointer (i.e. tensor_map) in gmem
+  };
+
+  // Tensor dims (d0...dN) followed by the im2col offsets (NumDims - 2 of them)
+  size_t Idx;
+  for (Idx = 5; Idx < (NumDims + NumOffsets + 5); Idx++)
+    Ops.push_back(N->getOperand(Idx));
+
+  // Idx now points at the multicast operand; cache_hint follows it.
+  // Push MultiCast operand, if available
+  if (IsMultiCast)
+    Ops.push_back(N->getOperand(Idx));
+
+  // Push CacheHint operand, if available
+  if (IsCacheHint)
+    Ops.push_back(N->getOperand(Idx + 1));
+
+  // Finally, the chain operand
+  Ops.push_back(N->getOperand(0));
+
+  bool IsShared32 =
+      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
+  unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode(
+      NumDims, IsShared32, IsMultiCast, IsCacheHint, /*IsIm2Col=*/true);
+  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
+}
+
+/// Replaces the G2S tile intrinsic node \p N with the matching
+/// CP_ASYNC_BULK_TENSOR_G2S_*_TILE* machine node.
+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2STile(SDNode *N) {
+  // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
+  // {dst, mbar, src, dims{d0...dN}, multicast, cache_hint,
+  // multicast_flag, cache_hint_flag}
+  // NumOperands = {Chain, IID} + {Actual intrinsic args}
+  //             = {2}          + {7 + dims}
+  size_t NumOps = N->getNumOperands();
+  size_t NumDims = NumOps - 9;
+  bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
+  bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
+
+  // Operand positions: dims start right after {dst, mbar, src}, and the
+  // multicast / cache_hint operands follow the last dim.
+  const size_t FirstDimIdx = 5;
+  const size_t MultiCastIdx = FirstDimIdx + NumDims;
+  const size_t CacheHintIdx = MultiCastIdx + 1;
+
+  SDLoc DL(N);
+  SmallVector<SDValue, 8> Ops{
+      N->getOperand(2), // Dst pointer in smem
+      N->getOperand(3), // Mbarrier pointer in smem
+      N->getOperand(4), // Src pointer (i.e. tensor_map) in gmem
+  };
+
+  // Tensor dims (d0...dN)
+  for (size_t Idx = FirstDimIdx; Idx < MultiCastIdx; Idx++)
+    Ops.push_back(N->getOperand(Idx));
+
+  // Push MultiCast operand, if available
+  if (IsMultiCast)
+    Ops.push_back(N->getOperand(MultiCastIdx));
+
+  // Push CacheHint operand, if available
+  if (IsCacheHint)
+    Ops.push_back(N->getOperand(CacheHintIdx));
+
+  // Finally, the chain operand
+  Ops.push_back(N->getOperand(0));
+
+  bool IsShared32 =
+      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
+  unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode(
+      NumDims, IsShared32, IsMultiCast, IsCacheHint, /*IsIm2Col=*/false);
+  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
+}
+
+/// Replaces the S2G intrinsic node \p N with the matching
+/// CP_ASYNC_BULK_TENSOR_S2G_* machine node; \p IsIm2Col picks the IM2COL
+/// opcodes instead of the TILE ones.
+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2GCommon(SDNode *N,
+                                                         bool IsIm2Col) {
+  // Operand layout: {Chain, Intrinsic-ID} followed by the intrinsic args
+  //   src, dst, dims{d0...dN}, cache_hint, cache_hint_flag
+  // which gives an operand count of 2 + (4 + dims).
+  const size_t OperandCount = N->getNumOperands();
+  const size_t FlagIdx = OperandCount - 1;
+  const size_t TensorRank = OperandCount - 6;
+  const bool HasCacheHint = N->getConstantOperandVal(FlagIdx) == 1;
+
+  // Forward everything from src up to the last dim; the cache_hint operand
+  // (located just before the flag) is included only when the flag is set.
+  size_t StopIdx = OperandCount - 2;
+  if (HasCacheHint)
+    StopIdx = FlagIdx;
+
+  SmallVector<SDValue, 8> Ops;
+  for (size_t OpIdx = 2; OpIdx < StopIdx; ++OpIdx)
+    Ops.push_back(N->getOperand(OpIdx));
+
+  // The chain operand goes last.
+  Ops.push_back(N->getOperand(0));
+
+  const unsigned SharedPtrBits =
+      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
+  const bool UsesShared32 = SharedPtrBits == 32;
+  const unsigned Opcode = GetCpAsyncBulkTensorS2GOpcode(
+      TensorRank, UsesShared32, HasCacheHint, IsIm2Col);
+
+  SDLoc DL(N);
+  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
+}
+
+bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
+  unsigned IID = N->getConstantOperandVal(1);
+  switch (IID) {
+  default:
+    return false;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d:
+    SelectCpAsyncBulkTensorS2GCommon(N);
+    return true;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d:
+    SelectCpAsyncBulkTensorS2GCommon(N, true /* IsIm2Col */);
----------------
Artem-B wrote:

Style nit: I think the conventional form is `/*ArgName=*/value`, i.e. `/*IsIm2Col=*/true` here.

https://github.com/llvm/llvm-project/pull/96083


More information about the llvm-commits mailing list