[Mlir-commits] [mlir] [mlir][nvgpu] Use the strides of the memref descriptor to construct the TMA descriptor (PR #85403)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 15 07:10:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Adam Paszke (apaszke)

<details>
<summary>Changes</summary>

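The previous version of the code assumed that the tensor was contiguous and derived the TMA global strides from the dimension sizes alone. A memref is not required to be contiguous, so strided (non-contiguous) memrefs were encoded with incorrect strides, which can cause surprising miscompiles. This change instead reads the strides from the memref descriptor and converts them to bytes using the element size.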

---
Full diff: https://github.com/llvm/llvm-project/pull/85403.diff


1 File Affected:

- (modified) mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (+19-13) 


``````````diff
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b9a3429e37b885..c76f8d77dff558 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -427,13 +427,21 @@ namespace {
 
 template <int rank>
 void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
-                               uint64_t *globalDim) {
+                               uint64_t *globalDim, uint64_t *globalStrides,
+                               const CUtensorMapDataType tensorDataType) {
   auto descriptor =
       reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
   *addr = descriptor->data;
   for (int i = 0; i < rank; ++i) {
     globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
   }
+  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
+                                               4, 8, 2, 4, 4, 4};
+  // TODO(grypp): Check that the minormost stride is equal to the element size.
+  for (int i = 0; i < rank - 1; ++i) {
+    globalStrides[i] = static_cast<uint64_t>(
+        descriptor->strides[rank - i - 2] * elementSizeInBytes[tensorDataType]);
+  }
 }
 
 } // namespace
@@ -457,19 +465,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
   char *globalAddress = nullptr;
   switch (tensorRank) {
   case 1:
-    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 2:
-    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 3:
-    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 4:
-    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 5:
-    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   default:
     fprintf(
@@ -478,17 +491,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     return NULL;
   }
 
-  static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
-                                           4, 8, 2, 4, 4, 4};
   for (int64_t r = 0; r < tensorRank; ++r) {
-    elementStrides[r] = uint32_t(1);
     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
   }
 
-  globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
-  for (int r = 1; r < tensorRank - 1; r++)
-    globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
   ScopedContext scopedContext;
   mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                            globalAddress, globalDim, globalStrides, boxDim,

``````````

</details>


https://github.com/llvm/llvm-project/pull/85403


More information about the Mlir-commits mailing list