[llvm] [NVPTX] Add TMA bulk tensor copy intrinsics (PR #96083)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 22 11:00:13 PDT 2024
================
@@ -4091,3 +4096,236 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
}
}
}
+
+// Maps a cp.async.bulk.tensor intrinsic ID to the tensor's dimensionality.
+// Each rank from 1 to 5 has a matching pair of intrinsics, one per copy
+// direction (smem_to_gmem and gmem_to_smem), which share the same dim count.
+static size_t GetCpAsyncBulkTensorDimFromIntrinsic(unsigned IID) {
+ switch (IID) {
+ case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_1d:
+ case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_1d:
+ return 1;
+ case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_2d:
+ case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_2d:
+ return 2;
+ case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_3d:
+ case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_3d:
+ return 3;
+ case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_4d:
+ case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_4d:
+ return 4;
+ case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_5d:
+ case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_5d:
+ return 5;
+ default:
+ // Callers must only pass one of the intrinsic IDs enumerated above.
+ llvm_unreachable(
+ "Invalid Tensor dim in nvvm_cp_async_bulk_tensor intrinsic");
+ }
+}
+
+// Expands to the NVPTX machine-opcode enum value for a bulk tensor copy.
+// `dir` is SMEM_TO_GMEM or GMEM_TO_SMEM, `dim` is 1D-5D, `mode` is TILE or
+// IM2COL, and `suffix` selects an optional multicast/cache-hint variant.
+// NOTE(review): relies on a boolean `IsShared32` being in scope at the
+// expansion site to pick the 32-bit shared-pointer opcode variant.
+#define CP_ASYNC_BULK_TENSOR_OPCODE(dir, dim, mode, suffix) \
+ (IsShared32 \
+ ? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix \
+ : NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix)
+
+// Shared->global (S2G) copies have only one optional variant: a cache-hint
+// (_CH) suffix. Relies on `IsCacheHint` being in scope at the expansion site.
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode) \
+ (IsCacheHint ? (CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, _CH)) \
+ : (CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, )))
+
+// Global->shared (G2S) copies have four variants from the multicast (_MC) and
+// cache-hint (_CH) combinations. The immediately-invoked lambda keeps the
+// expansion usable as a single expression; relies on `IsMultiCast` and
+// `IsCacheHint` being in scope at the expansion site.
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode) \
+ [&]() -> auto { \
+ if (IsMultiCast && IsCacheHint) \
+ return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC_CH); \
+ if (IsCacheHint) \
+ return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _CH); \
+ if (IsMultiCast) \
+ return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC); \
+ return CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, ); \
+ }()
+
+// Returns the NVPTX machine opcode for a shared->global (S2G) bulk tensor
+// copy of rank `Dim`, choosing the tile or im2col layout and the
+// 32-bit-shared-pointer / cache-hint variants as requested.
+// Note: im2col mode is only defined for 3-5 dims; tile mode supports 1-5.
+static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
+ bool IsCacheHint, bool IsIm2Col) {
+ if (IsIm2Col) {
+ switch (Dim) {
+ case 3:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL);
+ case 4:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL);
+ case 5:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL);
+ default:
+ llvm_unreachable("Invalid Dimension in im2col mode for "
+ "GetCpAsyncBulkTensorS2GOpcode.");
+ }
+ } else {
+ switch (Dim) {
+ case 1:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE);
+ case 2:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE);
+ case 3:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE);
+ case 4:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE);
+ case 5:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE);
+ default:
+ llvm_unreachable(
+ "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
+ }
+ }
+}
+
+// Returns the NVPTX machine opcode for a global->shared (G2S) bulk tensor
+// copy of rank `Dim`, choosing the tile or im2col layout and the
+// 32-bit-shared-pointer / multicast / cache-hint variants as requested.
+// Note: im2col mode is only defined for 3-5 dims; tile mode supports 1-5.
+static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
+ bool IsMultiCast,
+ bool IsCacheHint, bool IsIm2Col) {
+ if (IsIm2Col) {
+ switch (Dim) {
+ case 3:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL);
+ case 4:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL);
+ case 5:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL);
+ default:
+ llvm_unreachable("Invalid Dimension in im2col mode for "
+ "GetCpAsyncBulkTensorG2SOpcode.");
+ }
+ } else {
+ switch (Dim) {
+ case 1:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE);
+ case 2:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE);
+ case 3:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE);
+ case 4:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE);
+ case 5:
+ return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE);
+ default:
+ llvm_unreachable(
+ "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode.");
+ }
+ }
+}
+
+// Lowers an nvvm_cp_async_bulk_tensor_smem_to_gmem_[1-5]d intrinsic node to
+// the matching NVPTX machine node. Intrinsic operand layout (after the chain
+// at operand 0): 1 = intrinsic ID, 2 = packed flags, 3 = src pointer in smem,
+// 4 = dst tensor_map pointer in gmem, [5, 5+NumDims) = tensor dims, then the
+// cache-hint operand.
+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2G(SDNode *N) {
+  // Shared-memory pointers may be 32- or 64-bit; pick the matching opcode.
+  unsigned int SharedPointerSize =
+      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
+  bool IsShared32 = (SharedPointerSize == 32);
+
+  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+  size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);
+
+  // Decode the packed flag word. Writing Flags.V and then reading the
+  // bit-field view Flags.U would be type-punning through a union, which is
+  // undefined behavior in standard C++; copy the raw bits into the bit-field
+  // member instead, as the union guarantees both members share storage size.
+  ConstantSDNode *FlagsNode = cast<ConstantSDNode>(N->getOperand(2));
+  unsigned FlagsVal = static_cast<unsigned>(FlagsNode->getZExtValue());
+  nvvm::CpAsyncBulkTensorFlags Flags;
+  static_assert(sizeof(Flags.U) == sizeof(FlagsVal),
+                "flag bit-fields must cover the raw flag word exactly");
+  std::memcpy(&Flags.U, &FlagsVal, sizeof(Flags.U));
+  bool IsCacheHint = Flags.U.CacheHint == 1;
+  bool IsIm2Col = Flags.U.LoadMode == 1;
+
+  SDLoc DL(N);
+  // List of operands that are common to both variants. Worst case is
+  // 2 pointers + 5 dims + cache-hint + chain = 9 operands, so an inline
+  // capacity of 12 avoids a guaranteed heap allocation.
+  SmallVector<SDValue, 12> Ops{
+      N->getOperand(3), // Src pointer in smem
+      N->getOperand(4), // Dst tensor_map pointer in gmem
+  };
+
+  // Tensor Dims from [1-5] followed by the cache-hint operand
+  size_t TensorDimsStartIndex = 5;
+  size_t CacheHintIndex = TensorDimsStartIndex + NumDims;
+  for (size_t i = 0; i < NumDims; i++)
+    Ops.push_back(N->getOperand(TensorDimsStartIndex + i));
+
+  // Push the cache-hint operand, if available
+  if (IsCacheHint)
+    Ops.push_back(N->getOperand(CacheHintIndex));
+
+  // Finally, the chain operand
+  Ops.push_back(N->getOperand(0));
+
+  unsigned Opcode =
+      GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col);
+
+  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
+}
+
+void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2S(SDNode *N) {
+ unsigned int SharedPointerSize =
+ CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
+ bool IsShared32 = (SharedPointerSize == 32);
+
+ unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+ size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);
+
+ ConstantSDNode *FlagsNode = cast<ConstantSDNode>(N->getOperand(2));
+ nvvm::CpAsyncBulkTensorFlags Flags;
+ Flags.V = static_cast<unsigned>(FlagsNode->getZExtValue());
----------------
Artem-B wrote:
I guess here you would technically need to do something like this:
```
assert(sizeof(int_with_value) == sizeof(Flags.U));
memcpy(&Flags.U, &int_with_value, sizeof(Flags.U));
```
Then the fields in `Flags.U` should be valid to access.
https://github.com/llvm/llvm-project/pull/96083
More information about the llvm-commits
mailing list