[Mlir-commits] [mlir] [mlir][gpu] Generate multiple rank-specializations for tensor map creation (PR #74082)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 1 06:27:18 PST 2023


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: C/C++ code formatter, clang-format found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff f19571ee781de932390e8983267263f504e99e1f 3b5e8211a04fb85b596fc08b7d0e29916d87b9f6 -- mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 370f8eabe7..8b1fb0a559 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -425,22 +425,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
 
 namespace {
 
-template<int rank>
-void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr, uint64_t *globalDim) {
-  auto descriptor = reinterpret_cast<StridedMemRefType<char, rank>*>(raw_descriptor);
+template <int rank>
+void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
+                               uint64_t *globalDim) {
+  auto descriptor =
+      reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
   *addr = descriptor->data;
-  if constexpr (rank > 0) {  // rank 0 memrefs have no sizes
+  if constexpr (rank > 0) { // rank 0 memrefs have no sizes
     for (int i = 0; i < rank; ++i) {
       globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
     }
   }
 }
 
-}  // namespace
+} // namespace
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     int64_t tensorRank,                       // Dimensionality of tensor
-    void *ranked_descriptor,   // Starting address
+    void *ranked_descriptor,                  // Starting address
     const CUtensorMapDataType tensorDataType, // Stride size (in bytes)
     CUtensorMapInterleave interleave,         // Type of interleaved layout
     CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
@@ -456,27 +458,29 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
 
   char *globalAddress = nullptr;
   switch (tensorRank) {
-    case 0:
-      mgpuGetMemRefDataAndShape<0>(ranked_descriptor, &globalAddress, globalDim);
-      break;
-    case 1:
-      mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
-      break;
-    case 2:
-      mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
-      break;
-    case 3:
-      mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
-      break;
-    case 4:
-      mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
-      break;
-    case 5:
-      mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
-      break;
-    default:
-      fprintf(stderr, "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
-      return NULL;
+  case 0:
+    mgpuGetMemRefDataAndShape<0>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 1:
+    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 2:
+    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 3:
+    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 4:
+    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 5:
+    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  default:
+    fprintf(
+        stderr,
+        "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
+    return NULL;
   }
 
   static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,

``````````

</details>


https://github.com/llvm/llvm-project/pull/74082


More information about the Mlir-commits mailing list