[Parallel_libs-commits] [PATCH] D24619: [SE] Cache CUDA modules

Jason Henline via Parallel_libs-commits parallel_libs-commits at lists.llvm.org
Thu Sep 15 12:42:30 PDT 2016


jhen created this revision.
jhen added a reviewer: jlebar.
jhen added subscribers: parallel_libs-commits, jprice.
Herald added a subscriber: jlebar.

Instead of reloading a module every time a kernel from it is requested,
cache the loaded module and reuse the cached handle.

The CUDAPlatformDevice now also keeps handles to all its modules so they can be
unloaded if the device is cleared.

https://reviews.llvm.org/D24619

Files:
  streamexecutor/include/streamexecutor/platforms/cuda/CUDAPlatformDevice.h
  streamexecutor/lib/platforms/cuda/CUDAPlatformDevice.cpp

Index: streamexecutor/lib/platforms/cuda/CUDAPlatformDevice.cpp
===================================================================
--- streamexecutor/lib/platforms/cuda/CUDAPlatformDevice.cpp
+++ streamexecutor/lib/platforms/cuda/CUDAPlatformDevice.cpp
@@ -90,7 +90,6 @@
 
 Expected<const void *>
 CUDAPlatformDevice::createKernel(const MultiKernelLoaderSpec &Spec) {
-  // TODO(jhen): Maybe first check loaded modules?
   if (!Spec.hasCUDAPTXInMemory())
     return make_error("no CUDA code available to create kernel");
 
@@ -117,27 +116,31 @@
                       llvm::Twine(ComputeCapabilityMajor) + "." +
                       llvm::Twine(ComputeCapabilityMinor));
 
-  CUmodule Module;
-  if (CUresult Result = cuModuleLoadData(&Module, Code))
-    return CUresultToError(Result, "cuModuleLoadData");
-
-  CUfunction Function;
-  if (CUresult Result =
-          cuModuleGetFunction(&Function, Module, Spec.getKernelName().c_str()))
-    return CUresultToError(Result, "cuModuleGetFunction");
-
-  // TODO(jhen): Should I save this function pointer in case someone asks for
-  // it again?
-
-  // TODO(jhen): Should I save the module pointer so I can unload it when I
-  // destroy this device?
+  // Cache loaded modules by PTX text. One module may define several kernels,
+  // so only the module is cached and the kernel is always looked up by name;
+  // keying a cached function on the code alone would return the wrong kernel.
+  llvm::sys::ScopedLock Lock(Mutex);
+  CUmodule Module = nullptr;
+  auto Iterator = LoadedModules.find(Code);
+  if (Iterator != LoadedModules.end()) {
+    Module = Iterator->second;
+  } else {
+    if (CUresult Result = cuModuleLoadData(&Module, Code))
+      return CUresultToError(Result, "cuModuleLoadData");
+    // Record the module before the kernel lookup so the device keeps a handle
+    // to it even if cuModuleGetFunction fails below.
+    LoadedModules.emplace(Code, Module);
+  }
+
+  CUfunction Function = nullptr;
+  if (CUresult Result =
+          cuModuleGetFunction(&Function, Module, Spec.getKernelName().c_str()))
+    return CUresultToError(Result, "cuModuleGetFunction");
 
   return static_cast<const void *>(Function);
 }
 
 Error CUDAPlatformDevice::destroyKernel(const void *Handle) {
-  // TODO(jhen): Maybe keep track of kernels for each module and unload the
-  // module after they are all destroyed.
   return Error::success();
 }
 
Index: streamexecutor/include/streamexecutor/platforms/cuda/CUDAPlatformDevice.h
===================================================================
--- streamexecutor/include/streamexecutor/platforms/cuda/CUDAPlatformDevice.h
+++ streamexecutor/include/streamexecutor/platforms/cuda/CUDAPlatformDevice.h
@@ -17,6 +17,13 @@
 
 #include "streamexecutor/PlatformDevice.h"
 
+#include "llvm/Support/Mutex.h"
+
+#include <map>
+#include <string>
+
+struct CUmod_st;
+
 namespace streamexecutor {
 namespace cuda {
 
@@ -85,6 +92,14 @@
   CUDAPlatformDevice(size_t DeviceIndex) : DeviceIndex(DeviceIndex) {}
 
   int DeviceIndex;
+
+  /// Serializes access to LoadedModules.
+  llvm::sys::Mutex Mutex;
+
+  /// Modules loaded by createKernel, keyed by their PTX source text. Only the
+  /// module is cached because one module may define several kernels; each
+  /// kernel is looked up by name on every createKernel call.
+  std::map<std::string, CUmod_st *> LoadedModules;
 };
 
 } // namespace cuda


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D24619.71536.patch
Type: text/x-patch
Size: 2900 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/parallel_libs-commits/attachments/20160915/470ba4c4/attachment.bin>


More information about the Parallel_libs-commits mailing list