[Mlir-commits] [mlir] f69d5a7 - [mlir] Initialize CUDA context lazily.

Christian Sigg llvmlistbot at llvm.org
Thu Mar 4 04:08:05 PST 2021


Author: Christian Sigg
Date: 2021-03-04T13:07:56+01:00
New Revision: f69d5a7fc7e4f69580421ec47396f23d03cde0d0

URL: https://github.com/llvm/llvm-project/commit/f69d5a7fc7e4f69580421ec47396f23d03cde0d0
DIFF: https://github.com/llvm/llvm-project/commit/f69d5a7fc7e4f69580421ec47396f23d03cde0d0.diff

LOG: [mlir] Initialize CUDA context lazily.

Initializing the context lazily lets us remove the `-Wglobal-constructors` ignore-warning pragma again.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D97864

Added: 
    

Modified: 
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 8afb80e92e23..9122e34ba453 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -37,32 +37,34 @@
     fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                   \
   }(expr)
 
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wglobal-constructors"
-// Static reference to CUDA primary context for device ordinal 0.
-static CUcontext Context = [] {
-  CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
-  CUdevice device;
-  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
-  CUcontext context;
-  CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
-  return context;
-}();
-#pragma clang diagnostic pop
-
-// Sets the `Context` for the duration of the instance and restores the previous
-// context on destruction.
+// Make the primary context of device 0 current for the duration of the instance
+// and restore the previous context on destruction.
 class ScopedContext {
 public:
   ScopedContext() {
-    CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous));
-    CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
+    // Static reference to CUDA primary context for device ordinal 0. A
+    // function-local static is initialized on first use (thread-safe since
+    // C++11), so no global constructor is needed.
+    static CUcontext context = [] {
+      CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
+      // Initialized so a failed call below never leaves an indeterminate value
+      // to be read (CUDA_REPORT_IF_ERROR appears to only log, not abort).
+      CUdevice device = 0;
+      CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
+      CUcontext ctx = nullptr;
+      // Note: this does not affect the current context.
+      CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
+      return ctx;
+    }();
+
+    CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
   }
 
-  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); }
-
-private:
-  CUcontext previous;
+  // Each instance pops exactly once; a copy would pop the stack twice.
+  ScopedContext(const ScopedContext &) = delete;
+  ScopedContext &operator=(const ScopedContext &) = delete;
+
+  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
 };
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {


        


More information about the Mlir-commits mailing list