[Mlir-commits] [mlir] c64c04b - Clean up cuda-runtime-wrappers API.

Christian Sigg llvmlistbot at llvm.org
Tue Jul 28 07:34:19 PDT 2020


Author: Christian Sigg
Date: 2020-07-28T16:34:08+02:00
New Revision: c64c04bbaadbc35e265f12644b45787d6d077587

URL: https://github.com/llvm/llvm-project/commit/c64c04bbaadbc35e265f12644b45787d6d077587
DIFF: https://github.com/llvm/llvm-project/commit/c64c04bbaadbc35e265f12644b45787d6d077587.diff

LOG: Clean up cuda-runtime-wrappers API.

Do not return error codes; instead, return the created resource handles or void. Errors are reported by the wrapper functions themselves.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D84660

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
    mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
    mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 37b056263ab4..14011e08de02 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -39,7 +39,7 @@ static constexpr const char *kGpuModuleLoadName = "mgpuModuleLoad";
 static constexpr const char *kGpuModuleGetFunctionName =
     "mgpuModuleGetFunction";
 static constexpr const char *kGpuLaunchKernelName = "mgpuLaunchKernel";
-static constexpr const char *kGpuGetStreamHelperName = "mgpuGetStreamHelper";
+static constexpr const char *kGpuStreamCreateName = "mgpuStreamCreate";
 static constexpr const char *kGpuStreamSynchronizeName =
     "mgpuStreamSynchronize";
 static constexpr const char *kGpuMemHostRegisterName = "mgpuMemHostRegister";
@@ -100,12 +100,6 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
         getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
   }
 
-  LLVM::LLVMType getGpuRuntimeResultType() {
-    // This is declared as an enum in both CUDA and ROCm (HIP), but helpers
-    // use i32.
-    return getInt32Type();
-  }
-
   // Allocate a void pointer on the stack.
   Value allocatePointer(OpBuilder &builder, Location loc) {
     auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
@@ -168,27 +162,21 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::declareGpuRuntimeFunctions(
   if (!module.lookupSymbol(kGpuModuleLoadName)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuModuleLoadName,
-        LLVM::LLVMType::getFunctionTy(
-            getGpuRuntimeResultType(),
-            {
-                getPointerPointerType(), /* CUmodule *module */
-                getPointerType()         /* void *cubin */
-            },
-            /*isVarArg=*/false));
+        LLVM::LLVMType::getFunctionTy(getPointerType(),
+                                      {getPointerType()}, /* void *cubin */
+                                      /*isVarArg=*/false));
   }
   if (!module.lookupSymbol(kGpuModuleGetFunctionName)) {
     // The helper uses void* instead of CUDA's opaque CUmodule and
     // CUfunction, or ROCm (HIP)'s opaque hipModule_t and hipFunction_t.
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuModuleGetFunctionName,
-        LLVM::LLVMType::getFunctionTy(
-            getGpuRuntimeResultType(),
-            {
-                getPointerPointerType(), /* void **function */
-                getPointerType(),        /* void *module */
-                getPointerType()         /* char *name */
-            },
-            /*isVarArg=*/false));
+        LLVM::LLVMType::getFunctionTy(getPointerType(),
+                                      {
+                                          getPointerType(), /* void *module */
+                                          getPointerType()  /* char *name   */
+                                      },
+                                      /*isVarArg=*/false));
   }
   if (!module.lookupSymbol(kGpuLaunchKernelName)) {
     // Other than the CUDA or ROCm (HIP) api, the wrappers use uintptr_t to
@@ -198,7 +186,7 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::declareGpuRuntimeFunctions(
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuLaunchKernelName,
         LLVM::LLVMType::getFunctionTy(
-            getGpuRuntimeResultType(),
+            getVoidType(),
             {
                 getPointerType(),        /* void* f */
                 getIntPtrType(),         /* intptr_t gridXDim */
@@ -214,18 +202,18 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::declareGpuRuntimeFunctions(
             },
             /*isVarArg=*/false));
   }
-  if (!module.lookupSymbol(kGpuGetStreamHelperName)) {
-    // Helper function to get the current GPU compute stream. Uses void*
+  if (!module.lookupSymbol(kGpuStreamCreateName)) {
+    // Helper function to create a new GPU compute stream. Uses void*
     // instead of CUDA's opaque CUstream, or ROCm (HIP)'s opaque hipStream_t.
     builder.create<LLVM::LLVMFuncOp>(
-        loc, kGpuGetStreamHelperName,
+        loc, kGpuStreamCreateName,
         LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false));
   }
   if (!module.lookupSymbol(kGpuStreamSynchronizeName)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuStreamSynchronizeName,
-        LLVM::LLVMType::getFunctionTy(getGpuRuntimeResultType(),
-                                      getPointerType() /* CUstream stream */,
+        LLVM::LLVMType::getFunctionTy(getVoidType(),
+                                      {getPointerType()}, /* void *stream */
                                       /*isVarArg=*/false));
   }
   if (!module.lookupSymbol(kGpuMemHostRegisterName)) {
@@ -365,17 +353,13 @@ Value GpuLaunchFuncToGpuRuntimeCallsPass::generateKernelNameConstant(
 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
 //
 // %0 = call %binarygetter
-// %1 = alloca sizeof(void*)
-// call %moduleLoad(%2, %1)
-// %2 = alloca sizeof(void*)
-// %3 = load %1
-// %4 = <see generateKernelNameConstant>
-// call %moduleGetFunction(%2, %3, %4)
-// %5 = call %getStreamHelper()
-// %6 = load %2
-// %7 = <see setupParamsArray>
-// call %launchKernel(%6, <launchOp operands 0..5>, 0, %5, %7, nullptr)
-// call %streamSynchronize(%5)
+// %1 = call %moduleLoad(%0)
+// %2 = <see generateKernelNameConstant>
+// %3 = call %moduleGetFunction(%1, %2)
+// %4 = call %streamCreate()
+// %5 = <see setupParamsArray>
+// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
+// call %streamSynchronize(%4)
 void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
     mlir::gpu::LaunchFuncOp launchOp) {
   OpBuilder builder(launchOp);
@@ -405,36 +389,30 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
 
   // Emit the load module call to load the module data. Error checking is done
   // in the called helper function.
-  auto gpuModule = allocatePointer(builder, loc);
   auto gpuModuleLoad =
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleLoadName);
-  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getGpuRuntimeResultType()},
-                               builder.getSymbolRefAttr(gpuModuleLoad),
-                               ArrayRef<Value>{gpuModule, data});
+  auto module = builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getPointerType()},
+      builder.getSymbolRefAttr(gpuModuleLoad), ArrayRef<Value>{data});
   // Get the function from the module. The name corresponds to the name of
   // the kernel function.
-  auto gpuOwningModuleRef =
-      builder.create<LLVM::LoadOp>(loc, getPointerType(), gpuModule);
   auto kernelName = generateKernelNameConstant(
       launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, builder);
-  auto gpuFunction = allocatePointer(builder, loc);
   auto gpuModuleGetFunction =
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleGetFunctionName);
-  builder.create<LLVM::CallOp>(
-      loc, ArrayRef<Type>{getGpuRuntimeResultType()},
+  auto function = builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getPointerType()},
       builder.getSymbolRefAttr(gpuModuleGetFunction),
-      ArrayRef<Value>{gpuFunction, gpuOwningModuleRef, kernelName});
-  // Grab the global stream needed for execution.
-  auto gpuGetStreamHelper =
-      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuGetStreamHelperName);
-  auto gpuStream = builder.create<LLVM::CallOp>(
+      ArrayRef<Value>{module.getResult(0), kernelName});
+  // Create a new stream on which to launch the kernel.
+  auto gpuStreamCreate =
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamCreateName);
+  auto stream = builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getPointerType()},
-      builder.getSymbolRefAttr(gpuGetStreamHelper), ArrayRef<Value>{});
+      builder.getSymbolRefAttr(gpuStreamCreate), ArrayRef<Value>{});
   // Invoke the function with required arguments.
   auto gpuLaunchKernel =
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuLaunchKernelName);
-  auto gpuFunctionRef =
-      builder.create<LLVM::LoadOp>(loc, getPointerType(), gpuFunction);
   auto paramsArray = setupParamsArray(launchOp, builder);
   if (!paramsArray) {
     launchOp.emitOpError() << "cannot pass given parameters to the kernel";
@@ -443,21 +421,21 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
   auto nullpointer =
       builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
   builder.create<LLVM::CallOp>(
-      loc, ArrayRef<Type>{getGpuRuntimeResultType()},
+      loc, ArrayRef<Type>{getVoidType()},
       builder.getSymbolRefAttr(gpuLaunchKernel),
-      ArrayRef<Value>{gpuFunctionRef, launchOp.getOperand(0),
+      ArrayRef<Value>{function.getResult(0), launchOp.getOperand(0),
                       launchOp.getOperand(1), launchOp.getOperand(2),
                       launchOp.getOperand(3), launchOp.getOperand(4),
                       launchOp.getOperand(5), zero, /* sharedMemBytes */
-                      gpuStream.getResult(0),       /* stream */
+                      stream.getResult(0),          /* stream */
                       paramsArray,                  /* kernel params */
                       nullpointer /* extra */});
   // Sync on the stream to make it synchronous.
   auto gpuStreamSync =
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamSynchronizeName);
-  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getGpuRuntimeResultType()},
+  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
                                builder.getSymbolRefAttr(gpuStreamSync),
-                               ArrayRef<Value>(gpuStream.getResult(0)));
+                               ArrayRef<Value>(stream.getResult(0)));
   launchOp.erase();
 }
 

diff  --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index a3381465ebf2..bdcde0be60c2 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -20,13 +20,11 @@ module attributes {gpu.container_module} {
 
     // CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]]
     // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index)
-    // CHECK: %[[binary_ptr:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]]
+    // CHECK: %[[binary:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]]
     // CHECK-SAME: -> !llvm<"i8*">
-    // CHECK: %[[module_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
-    // CHECK: llvm.call @mgpuModuleLoad(%[[module_ptr]], %[[binary_ptr]]) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
-    // CHECK: %[[func_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
-    // CHECK: llvm.call @mgpuModuleGetFunction(%[[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
-    // CHECK: llvm.call @mgpuGetStreamHelper
+    // CHECK: %[[module:.*]] = llvm.call @mgpuModuleLoad(%[[binary]]) : (!llvm<"i8*">) -> !llvm<"i8*">
+    // CHECK: %[[func:.*]] = llvm.call @mgpuModuleGetFunction(%[[module]], {{.*}}) : (!llvm<"i8*">, !llvm<"i8*">) -> !llvm<"i8*">
+    // CHECK: llvm.call @mgpuStreamCreate
     // CHECK: llvm.call @mgpuLaunchKernel
     // CHECK: llvm.call @mgpuStreamSynchronize
     "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel_module::@kernel }

diff  --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 2b71eb34703b..8e2dc029fa9f 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -21,54 +21,53 @@
 
 #include "cuda.h"
 
-namespace {
-int32_t reportErrorIfAny(CUresult result, const char *where) {
-  if (result != CUDA_SUCCESS) {
-    llvm::errs() << "CUDA failed with " << result << " in " << where << "\n";
-  }
-  return result;
+// Evaluates `expr`, which must yield a CUresult. On a nonzero result, prints
+// the stringified expression and the error name from cuGetErrorName to
+// llvm::errs(). The error is reported only; it is not propagated.
+#define CUDA_REPORT_IF_ERROR(expr)                                             \
+  [](CUresult result) {                                                        \
+    if (!result)                                                               \
+      return;                                                                  \
+    const char *name = nullptr;                                                \
+    cuGetErrorName(result, &name);                                             \
+    if (!name)                                                                 \
+      name = "<unknown>";                                                      \
+    llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
+  }(expr)
+
+extern "C" CUmodule mgpuModuleLoad(void *data) {
+  CUmodule module = nullptr;
+  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
+  return module;
 }
-} // anonymous namespace
 
-extern "C" int32_t mgpuModuleLoad(void **module, void *data) {
-  int32_t err = reportErrorIfAny(
-      cuModuleLoadData(reinterpret_cast<CUmodule *>(module), data),
-      "ModuleLoad");
-  return err;
-}
-
-extern "C" int32_t mgpuModuleGetFunction(void **function, void *module,
-                                         const char *name) {
-  return reportErrorIfAny(
-      cuModuleGetFunction(reinterpret_cast<CUfunction *>(function),
-                          reinterpret_cast<CUmodule>(module), name),
-      "GetFunction");
+extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
+  CUfunction function = nullptr;
+  CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
+  return function;
 }
 
 // The wrapper uses intptr_t instead of CUDA's unsigned int to match
 // the type of MLIR's index type. This avoids the need for casts in the
 // generated MLIR code.
-extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX,
-                                    intptr_t gridY, intptr_t gridZ,
-                                    intptr_t blockX, intptr_t blockY,
-                                    intptr_t blockZ, int32_t smem, void *stream,
-                                    void **params, void **extra) {
-  return reportErrorIfAny(
-      cuLaunchKernel(reinterpret_cast<CUfunction>(function), gridX, gridY,
-                     gridZ, blockX, blockY, blockZ, smem,
-                     reinterpret_cast<CUstream>(stream), params, extra),
-      "LaunchKernel");
+extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
+                                 intptr_t gridY, intptr_t gridZ,
+                                 intptr_t blockX, intptr_t blockY,
+                                 intptr_t blockZ, int32_t smem, CUstream stream,
+                                 void **params, void **extra) {
+  CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
+                                      blockY, blockZ, smem, stream, params,
+                                      extra));
 }
 
-extern "C" void *mgpuGetStreamHelper() {
-  CUstream stream;
-  reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate");
+extern "C" CUstream mgpuStreamCreate() {
+  CUstream stream = nullptr;
+  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
   return stream;
 }
 
-extern "C" int32_t mgpuStreamSynchronize(void *stream) {
-  return reportErrorIfAny(
-      cuStreamSynchronize(reinterpret_cast<CUstream>(stream)), "StreamSync");
+extern "C" void mgpuStreamSynchronize(CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
 }
 
 /// Helper functions for writing mlir example code
@@ -76,17 +75,16 @@ extern "C" int32_t mgpuStreamSynchronize(void *stream) {
 // Allows to register byte array with the CUDA runtime. Helpful until we have
 // transfer functions implemented.
 extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
-  reportErrorIfAny(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0),
-                   "MemHostRegister");
+  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
 // Allows to register a MemRef with the CUDA runtime. Initializes array with
 // value. Helpful until we have transfer functions implemented.
 template <typename T>
-void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
-  llvm::SmallVector<int64_t, 4> denseStrides(mem_ref.rank);
-  llvm::ArrayRef<int64_t> sizes(mem_ref.sizes, mem_ref.rank);
-  llvm::ArrayRef<int64_t> strides(mem_ref.strides, mem_ref.rank);
+void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
+  llvm::SmallVector<int64_t, 4> denseStrides(memRef.rank);
+  llvm::ArrayRef<int64_t> sizes(memRef.sizes, memRef.rank);
+  llvm::ArrayRef<int64_t> strides(memRef.strides, memRef.rank);
 
   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                    std::multiplies<int64_t>());
@@ -98,17 +96,17 @@ void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
   denseStrides.back() = 1;
   assert(strides == llvm::makeArrayRef(denseStrides));
 
-  auto *pointer = mem_ref.data + mem_ref.offset;
+  auto *pointer = memRef.data + memRef.offset;
   std::fill_n(pointer, count, value);
   mgpuMemHostRegister(pointer, count * sizeof(T));
 }
 
 extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
-  UnrankedMemRefType<float> mem_ref = {rank, ptr};
-  mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(mem_ref), 1.23f);
+  UnrankedMemRefType<float> memRef = {rank, ptr};
+  mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(memRef), 1.23f);
 }
 
 extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
-  UnrankedMemRefType<int32_t> mem_ref = {rank, ptr};
-  mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(mem_ref), 123);
+  UnrankedMemRefType<int32_t> memRef = {rank, ptr};
+  mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(memRef), 123);
 }

diff  --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
index f49e6c91ea65..b97ce695ac42 100644
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -21,56 +21,54 @@
 
 #include "hip/hip_runtime.h"
 
-namespace {
-int32_t reportErrorIfAny(hipError_t result, const char *where) {
-  if (result != hipSuccess) {
-    llvm::errs() << "HIP failed with " << result << " in " << where << "\n";
-  }
-  return result;
+// Evaluates `expr`, which must yield a hipError_t. On a nonzero result, prints
+// the stringified expression and the error name from hipGetErrorName to
+// llvm::errs(). The error is reported only; it is not propagated.
+#define HIP_REPORT_IF_ERROR(expr)                                              \
+  [](hipError_t result) {                                                      \
+    if (!result)                                                               \
+      return;                                                                  \
+    const char *name = hipGetErrorName(result);                                \
+    if (!name)                                                                 \
+      name = "<unknown>";                                                      \
+    llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
+  }(expr)
+
+extern "C" hipModule_t mgpuModuleLoad(void *data) {
+  hipModule_t module = nullptr;
+  HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
+  return module;
 }
-} // anonymous namespace
 
-extern "C" int32_t mgpuModuleLoad(void **module, void *data) {
-  int32_t err = reportErrorIfAny(
-      hipModuleLoadData(reinterpret_cast<hipModule_t *>(module), data),
-      "ModuleLoad");
-  return err;
-}
-
-extern "C" int32_t mgpuModuleGetFunction(void **function, void *module,
-                                         const char *name) {
-  return reportErrorIfAny(
-      hipModuleGetFunction(reinterpret_cast<hipFunction_t *>(function),
-                           reinterpret_cast<hipModule_t>(module), name),
-      "GetFunction");
+extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
+                                               const char *name) {
+  hipFunction_t function = nullptr;
+  HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name));
+  return function;
 }
 
 // The wrapper uses intptr_t instead of ROCM's unsigned int to match
 // the type of MLIR's index type. This avoids the need for casts in the
 // generated MLIR code.
-extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX,
-                                    intptr_t gridY, intptr_t gridZ,
-                                    intptr_t blockX, intptr_t blockY,
-                                    intptr_t blockZ, int32_t smem, void *stream,
-                                    void **params, void **extra) {
-  return reportErrorIfAny(
-      hipModuleLaunchKernel(reinterpret_cast<hipFunction_t>(function), gridX,
-                            gridY, gridZ, blockX, blockY, blockZ, smem,
-                            reinterpret_cast<hipStream_t>(stream), params,
-                            extra),
-      "LaunchKernel");
+extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
+                                 intptr_t gridY, intptr_t gridZ,
+                                 intptr_t blockX, intptr_t blockY,
+                                 intptr_t blockZ, int32_t smem,
+                                 hipStream_t stream, void **params,
+                                 void **extra) {
+  HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
+                                            blockX, blockY, blockZ, smem,
+                                            stream, params, extra));
 }
 
-extern "C" void *mgpuGetStreamHelper() {
-  hipStream_t stream;
-  reportErrorIfAny(hipStreamCreate(&stream), "StreamCreate");
+extern "C" hipStream_t mgpuStreamCreate() {
+  hipStream_t stream = nullptr;
+  HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
   return stream;
 }
 
-extern "C" int32_t mgpuStreamSynchronize(void *stream) {
-  return reportErrorIfAny(
-      hipStreamSynchronize(reinterpret_cast<hipStream_t>(stream)),
-      "StreamSync");
+extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
+  HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
 }
 
 /// Helper functions for writing mlir example code
@@ -78,8 +76,7 @@ extern "C" int32_t mgpuStreamSynchronize(void *stream) {
 // Allows to register byte array with the ROCM runtime. Helpful until we have
 // transfer functions implemented.
 extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
-  reportErrorIfAny(hipHostRegister(ptr, sizeBytes, /*flags=*/0),
-                   "MemHostRegister");
+  HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
 // Allows to register a MemRef with the ROCM runtime. Initializes array with
@@ -120,8 +117,7 @@ extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
 
 template <typename T>
 void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
-  reportErrorIfAny(hipSetDevice(0), "hipSetDevice");
-  reportErrorIfAny(
-      hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0),
-      "hipHostGetDevicePointer");
+  HIP_REPORT_IF_ERROR(hipSetDevice(0));
+  HIP_REPORT_IF_ERROR(
+      hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0));
 }


        


More information about the Mlir-commits mailing list