[llvm] [NVPTX] Add TMA bulk tensor copy intrinsics (PR #96083)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 22 11:00:13 PDT 2024


================
@@ -4091,3 +4096,236 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
     }
   }
 }
+
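+// Returns the tensor dimensionality (1-5) encoded in the given
+// nvvm_cp_async_bulk_tensor_* intrinsic ID.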
+static size_t GetCpAsyncBulkTensorDimFromIntrinsic(unsigned IID) {
+  switch (IID) {
+  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_1d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_1d:
+    return 1;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_2d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_2d:
+    return 2;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_3d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_3d:
+    return 3;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_4d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_4d:
+    return 4;
+  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_5d:
+  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_5d:
+    return 5;
+  default:
+    llvm_unreachable(
+        "Invalid Tensor dim in nvvm_cp_async_bulk_tensor intrinsic");
+  }
+}
+
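+// Compose the machine opcode name from the copy direction (SMEM_TO_GMEM
+// or GMEM_TO_SMEM), dimension, mode (TILE or IM2COL) and an optional
+// multicast/cache-hint suffix. IsShared32 selects the variant that takes
+// 32-bit shared-memory pointers.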
+#define CP_ASYNC_BULK_TENSOR_OPCODE(dir, dim, mode, suffix)                    \
+  (IsShared32                                                                  \
+       ? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix   \
+       : NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix)
+
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode)                         \
+  (IsCacheHint ? (CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, _CH))   \
+               : (CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, )))
+
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode)                         \
+  [&]() -> auto {                                                              \
+    if (IsMultiCast && IsCacheHint)                                            \
+      return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC_CH);     \
+    if (IsCacheHint)                                                           \
+      return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _CH);        \
+    if (IsMultiCast)                                                           \
+      return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC);        \
+    return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, );             \
+  }()
+
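+// Pick the S2G (shared-to-global) opcode for the given dimension and
+// mode; im2col is only defined for 3D-5D tensors.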
+static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
+                                              bool IsCacheHint, bool IsIm2Col) {
+  if (IsIm2Col) {
+    switch (Dim) {
+    case 3:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL);
+    case 4:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL);
+    case 5:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL);
+    default:
+      llvm_unreachable("Invalid Dimension in im2col mode for "
+                       "GetCpAsyncBulkTensorS2GOpcode.");
+    }
+  } else {
+    switch (Dim) {
+    case 1:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE);
+    case 2:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE);
+    case 3:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE);
+    case 4:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE);
+    case 5:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE);
+    default:
+      llvm_unreachable(
+          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
+    }
+  }
+}
+
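+// Pick the G2S (global-to-shared) opcode; unlike S2G, this also encodes
+// the multicast flag.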
+static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
+                                              bool IsMultiCast,
+                                              bool IsCacheHint, bool IsIm2Col) {
+  if (IsIm2Col) {
+    switch (Dim) {
+    case 3:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL);
+    case 4:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL);
+    case 5:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL);
+    default:
+      llvm_unreachable("Invalid Dimension in im2col mode for "
+                       "GetCpAsyncBulkTensorG2SOpcode.");
+    }
+  } else {
+    switch (Dim) {
+    case 1:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE);
+    case 2:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE);
+    case 3:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE);
+    case 4:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE);
+    case 5:
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE);
+    default:
+      llvm_unreachable(
+          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode.");
+    }
+  }
+}
+
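+// Intrinsic operand layout: (chain, IID, flags, src pointer in smem,
+// dst tensor-map pointer in gmem, tensor dims..., optional cache-hint).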
+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2G(SDNode *N) {
+  unsigned int SharedPointerSize =
+      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
+  bool IsShared32 = (SharedPointerSize == 32);
+
+  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+  size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);
+
+  ConstantSDNode *FlagsNode = cast<ConstantSDNode>(N->getOperand(2));
+  nvvm::CpAsyncBulkTensorFlags Flags;
+  Flags.V = static_cast<unsigned>(FlagsNode->getZExtValue());
+  bool IsCacheHint = Flags.U.CacheHint == 1;
+  bool IsIm2Col = Flags.U.LoadMode == 1;
+
+  SDLoc DL(N);
+  // List of operands that are common to both variants
+  SmallVector<SDValue, 4> Ops{
+      N->getOperand(3), // Src pointer in smem
+      N->getOperand(4), // Dst tensor_map pointer in gmem
+  };
+
+  // Tensor dimension operands (1 to 5 of them), followed by the optional
+  // cache-hint operand
+  size_t TensorDimsStartIndex = 5;
+  size_t CacheHintIndex = TensorDimsStartIndex + NumDims;
+  for (size_t i = 0; i < NumDims; i++)
+    Ops.push_back(N->getOperand(TensorDimsStartIndex + i));
+
+  // Push the cache-hint operand, if available
+  if (IsCacheHint)
+    Ops.push_back(N->getOperand(CacheHintIndex));
+
+  // Finally, the chain operand
+  Ops.push_back(N->getOperand(0));
+
+  unsigned Opcode =
+      GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col);
+
+  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
+}
+
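+// Mirrors the S2G selection above; here the flags word additionally
+// carries the multicast bit.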
+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2S(SDNode *N) {
+  unsigned int SharedPointerSize =
+      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
+  bool IsShared32 = (SharedPointerSize == 32);
+
+  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+  size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);
+
+  ConstantSDNode *FlagsNode = cast<ConstantSDNode>(N->getOperand(2));
+  nvvm::CpAsyncBulkTensorFlags Flags;
+  Flags.V = static_cast<unsigned>(FlagsNode->getZExtValue());
----------------
Artem-B wrote:

I guess here you would technically need to do something like this, since writing `Flags.V` and then reading the `Flags.U` bit-fields type-puns through the union:
```
assert(sizeof(int_with_value) == sizeof(Flags.U));
memcpy(&Flags.U, &int_with_value, sizeof(Flags.U));
```

Then the fields in `Flags.U` should be valid to access.
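
For illustration, a minimal self-contained sketch of that pattern: `FlagBits` is a hypothetical stand-in for the bit-field half of `nvvm::CpAsyncBulkTensorFlags` (field names mirror the `Flags.U` uses in the diff above), and `llvm::bit_cast` from `llvm/ADT/bit.h` wraps exactly this assert-plus-memcpy dance:
```
#include "llvm/ADT/bit.h"
#include <cstdint>

// Hypothetical stand-in for the bit-field half of
// nvvm::CpAsyncBulkTensorFlags.
struct FlagBits {
  unsigned CacheHint : 1;
  unsigned LoadMode : 1;
  unsigned Reserved : 30;
};
static_assert(sizeof(FlagBits) == sizeof(unsigned),
              "flag word and bit-field view must be the same size");

static FlagBits decodeFlags(uint64_t Raw) {
  // Same effect as memcpy'ing the low 32 bits into the struct,
  // with the size check done at compile time.
  return llvm::bit_cast<FlagBits>(static_cast<unsigned>(Raw));
}
```
With that, `decodeFlags(FlagsNode->getZExtValue()).CacheHint` reads the field without any union type-punning.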


https://github.com/llvm/llvm-project/pull/96083

