[Mlir-commits] [mlir] [mlir][cuda] Avoid driver call to check max shared memory (PR #70021)

Guray Ozen llvmlistbot at llvm.org
Tue Oct 24 03:12:56 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/70021

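This PR guards the driver calls in mgpuLaunchKernel with an if statement, so they are skipped when no dynamic shared memory is requested (smem is 0). A driver call is considerably more expensive than a branch.

As a future TODO, the compiler could generate the if statement itself, which would allow it to be optimized away in some cases.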
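
The guard can be illustrated in isolation. The following is a minimal sketch, not the code from the patch: queryMaxSharedMemory is a hypothetical stand-in for the cuDeviceGetAttribute query, the 48 KiB limit is an assumed value, and the call counter exists only to make the effect of the guard observable without a GPU. Unlike the patch, which only prints a message, the sketch also returns whether the request fits.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the driver query (cuDeviceGetAttribute in the
// patch). It counts its calls so the effect of the guard can be observed.
static int driverQueryCount = 0;

static int32_t queryMaxSharedMemory() {
  ++driverQueryCount;
  return 48 * 1024; // assumed device limit in bytes, for illustration only
}

// Returns true when the requested dynamic shared memory fits.
// The (expensive) query only runs when smem > 0.
static bool prepareSharedMemory(int32_t smem) {
  if (smem <= 0)
    return true; // nothing requested: skip the query entirely

  int32_t maxShmem = queryMaxSharedMemory();
  if (maxShmem < smem) {
    std::fprintf(stderr,
                 "Requested shared memory (%d bytes) is larger than the "
                 "maximum allowed (%d bytes)\n",
                 smem, maxShmem);
    return false;
  }
  return true;
}
```

With this shape, a launch that requests no dynamic shared memory pays only for the comparison, while any positive request still goes through the full check.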

From 0566391cf03a1a5a2c4fdd3223e91cf8d59576ca Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 24 Oct 2023 12:04:22 +0200
Subject: [PATCH] [mlir][cuda] Avoid driver call to check max shared memory

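This PR guards the driver calls in mgpuLaunchKernel with an if statement, so they are skipped when no dynamic shared memory is requested (smem is 0). A driver call is considerably more expensive than a branch.

As a future TODO, the compiler could generate the if statement itself, which would allow it to be optimized away in some cases.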
---
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 55db744af021c14..a8e743c519135f7 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -168,20 +168,23 @@ mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                  intptr_t blockZ, int32_t smem, CUstream stream, void **params,
                  void **extra, size_t /*paramsCount*/) {
   ScopedContext scopedContext;
-  int32_t maxShmem = 0;
-  CUdevice device = getDefaultCuDevice();
-  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
-  CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
-      &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
-      device));
-  if (maxShmem < smem) {
-    fprintf(stderr,
-            "Requested shared memory (%dkb) is larger than maximum allowed "
-            "shared memory (%dkb) for this device\n",
-            smem, maxShmem);
+  if (smem > 0) {
+    // Avoid checking driver as it's more expensive than if statement
+    int32_t maxShmem = 0;
+    CUdevice device = getDefaultCuDevice();
+    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+        device));
+    if (maxShmem < smem) {
+      fprintf(stderr,
+              "Requested shared memory (%dkb) is larger than maximum allowed "
+              "shared memory (%dkb) for this device\n",
+              smem, maxShmem);
+    }
+    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
   }
-  CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
-      function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
   debug_print("Launching kernel, grid=%ld,%ld,%ld, "
               "threads: %ld, %ld, %ld, "
               "smem: %dkb\n",


