[Mlir-commits] [mlir] [mlir][nvgpu] Use the strides of the memref descriptor to construct the TMA descriptor (PR #85403)
Adam Paszke
llvmlistbot at llvm.org
Fri Mar 15 07:10:24 PDT 2024
https://github.com/apaszke created https://github.com/llvm/llvm-project/pull/85403
The previous version of the code assumed that the tensor was contiguous in memory. Memrefs are not required to be contiguous, so this assumption could cause surprising miscompiles.
From aa215e23acf70cb5abf178571a058d737729d6e5 Mon Sep 17 00:00:00 2001
From: Adam Paszke <apaszke at google.com>
Date: Fri, 15 Mar 2024 14:09:25 +0000
Subject: [PATCH] Use the strides of the memref descriptor to construct the TMA
descriptor
The previous version of the code assumed that the tensor was contiguous in memory.
Memrefs are not required to be contiguous, so this could cause surprising miscompiles.
---
.../ExecutionEngine/CudaRuntimeWrappers.cpp | 32 +++++++++++--------
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b9a3429e37b885..c76f8d77dff558 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -427,13 +427,21 @@ namespace {
template <int rank>
void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
- uint64_t *globalDim) {
+ uint64_t *globalDim, uint64_t *globalStrides, // outputs: writes rank sizes and rank-1 byte strides, last memref dim first
+ const CUtensorMapDataType tensorDataType) { // Extracts the data pointer, sizes and byte strides of a rank-`rank` StridedMemRefType for TMA encoding; tensorDataType only selects the element byte size.
auto descriptor =
reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
*addr = descriptor->data;
for (int i = 0; i < rank; ++i) {
globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
}
+ static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2, // bytes per element, indexed by tensorDataType; NOTE(review): 13 entries and tensorDataType is not bounds-checked
+ 4, 8, 2, 4, 4, 4};
+ // TODO(grypp): Check that the minormost stride (descriptor->strides[rank - 1], never read here) is 1 element; it is assumed and never written to globalStrides.
+ for (int i = 0; i < rank - 1; ++i) { // globalStrides[i] is the byte stride of globalDim[i + 1]; globalDim[0] gets no stride entry
+ globalStrides[i] = static_cast<uint64_t>(
+ descriptor->strides[rank - i - 2] * elementSizeInBytes[tensorDataType]);
+ } // NOTE(review): TMA may require these byte strides to be multiples of 16 — confirm against the cuTensorMapEncodeTiled docs and consider validating here.
}
} // namespace
@@ -457,19 +465,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
char *globalAddress = nullptr;
switch (tensorRank) {
case 1:
- mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 2:
- mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 3:
- mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 4:
- mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 5:
- mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
default:
fprintf(
@@ -478,17 +491,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
return NULL;
}
- static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
- 4, 8, 2, 4, 4, 4};
for (int64_t r = 0; r < tensorRank; ++r) {
- elementStrides[r] = uint32_t(1);
boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
}
- globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
- for (int r = 1; r < tensorRank - 1; r++)
- globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
ScopedContext scopedContext;
mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
globalAddress, globalDim, globalStrides, boxDim,
More information about the Mlir-commits
mailing list