[Mlir-commits] [mlir] [mlir][nvgpu] Use the strides of the memref descriptor to construct the TMA descriptor (PR #85403)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 15 07:10:52 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Adam Paszke (apaszke)
<details>
<summary>Changes</summary>
The previous version of the code assumed that the tensor was contiguous, which is not required and can cause surprising miscompiles.
---
Full diff: https://github.com/llvm/llvm-project/pull/85403.diff
1 file affected:
- (modified) mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (+19-13)
``````````diff
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b9a3429e37b885..c76f8d77dff558 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -427,13 +427,21 @@ namespace {
template <int rank>
void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
- uint64_t *globalDim) {
+ uint64_t *globalDim, uint64_t *globalStrides,
+ const CUtensorMapDataType tensorDataType) {
auto descriptor =
reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
*addr = descriptor->data;
for (int i = 0; i < rank; ++i) {
globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
}
+ static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
+ 4, 8, 2, 4, 4, 4};
+ // TODO(grypp): Check that the minormost stride is equal to the element size.
+ for (int i = 0; i < rank - 1; ++i) {
+ globalStrides[i] = static_cast<uint64_t>(
+ descriptor->strides[rank - i - 2] * elementSizeInBytes[tensorDataType]);
+ }
}
} // namespace
@@ -457,19 +465,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
char *globalAddress = nullptr;
switch (tensorRank) {
case 1:
- mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 2:
- mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 3:
- mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 4:
- mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
case 5:
- mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+ mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim,
+ globalStrides, tensorDataType);
break;
default:
fprintf(
@@ -478,17 +491,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
return NULL;
}
- static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
- 4, 8, 2, 4, 4, 4};
for (int64_t r = 0; r < tensorRank; ++r) {
- elementStrides[r] = uint32_t(1);
boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
}
- globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
- for (int r = 1; r < tensorRank - 1; r++)
- globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
ScopedContext scopedContext;
mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
globalAddress, globalDim, globalStrides, boxDim,
``````````
</details>
https://github.com/llvm/llvm-project/pull/85403
More information about the Mlir-commits
mailing list