[Mlir-commits] [mlir] 2fd6403 - [mlir][gpu] Introduce mlir-rocm-runner.

Wen-Heng Chung llvmlistbot at llvm.org
Fri Jun 5 07:46:53 PDT 2020


Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T09:46:39-05:00
New Revision: 2fd6403a6d7a81c3c9d2676643bfeda042359d3c

URL: https://github.com/llvm/llvm-project/commit/2fd6403a6d7a81c3c9d2676643bfeda042359d3c
DIFF: https://github.com/llvm/llvm-project/commit/2fd6403a6d7a81c3c9d2676643bfeda042359d3c.diff

LOG: [mlir][gpu] Introduce mlir-rocm-runner.

Summary:
`mlir-rocm-runner` is introduced in this commit to execute GPU modules on the ROCm
platform. A small wrapper library that encapsulates ROCm's HIP runtime API is
also included in this commit.

Due to the behavior of ROCm, raw pointers inside memrefs passed to `gpu.launch`
must be modified on the host side so that they hold the pointer values that are
addressable on the GPU.

LLVM MC is used to assemble the AMD GCN ISA produced by
`ConvertGPUKernelToBlobPass` into binary form, and LLD is used to produce a shared
ELF object that can be loaded by the ROCm HIP runtime.

gfx900 is the default target used right now, although it can be changed via
an option in `mlir-rocm-runner`. Future revisions may consider using the ROCm Agent
Enumerator to detect the right target on the system.

Note that AMDGPU Code Object V2 is used in this revision. Future enhancements may
upgrade to AMDGPU Code Object V3.

Bitcode libraries in ROCm-Device-Libs, which implement the math routines exposed in
the `rocdl` dialect, are not yet linked; this is left as a TODO in the logic.

Reviewers: herhut

Subscribers: mgorny, tpr, dexonsmith, mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, csigg, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits

Tags: #mlir, #llvm

Differential Revision: https://reviews.llvm.org/D80676

Added: 
    mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
    mlir/test/mlir-rocm-runner/lit.local.cfg
    mlir/test/mlir-rocm-runner/two-modules.mlir
    mlir/test/mlir-rocm-runner/vecadd.mlir
    mlir/tools/mlir-rocm-runner/CMakeLists.txt
    mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
    mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

Modified: 
    mlir/CMakeLists.txt
    mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/test/CMakeLists.txt
    mlir/test/lit.cfg.py
    mlir/test/lit.site.cfg.py.in
    mlir/tools/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 0cf1e8d44516..291a7f26d3cc 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -41,6 +41,7 @@ endif()
 add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ROCM_CONVERSIONS_ENABLED})
 
 set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
+set(MLIR_ROCM_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir ROCm runner")
 set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner")
 
 option(MLIR_INCLUDE_TESTS

diff  --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index b022ebc042c9..56bc5f2c2c4c 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -43,7 +43,8 @@ using LoweringCallback =
 /// instead uses a small wrapper library that exports a stable and conveniently
 /// typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP).
 std::unique_ptr<OperationPass<ModuleOp>>
-createConvertGpuLaunchFuncToGpuRuntimeCallsPass();
+createConvertGpuLaunchFuncToGpuRuntimeCallsPass(
+    StringRef gpuBinaryAnnotation = "");
 
 /// Creates a pass to convert kernel functions into GPU target object blobs.
 ///

diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 7bd388803d96..5f922c84a9a1 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -123,6 +123,11 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
   void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
 
 public:
+  /// Creates the pass with the default `gpuBinaryAnnotation` option value.
+  GpuLaunchFuncToGpuRuntimeCallsPass() = default;
+  /// Creates the pass reading kernel binaries from the attribute named
+  /// `gpuBinaryAnnotation` (e.g. "rocdl.hsaco"). Marked explicit so a bare
+  /// StringRef never silently converts into a pass object.
+  explicit GpuLaunchFuncToGpuRuntimeCallsPass(StringRef gpuBinaryAnnotation) {
+    this->gpuBinaryAnnotation = gpuBinaryAnnotation.str();
+  }
+
   // Run the dialect converter on the module.
   void runOnOperation() override {
     // Cache the LLVMDialect for the current module.
@@ -457,6 +462,8 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
 }
 
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createConvertGpuLaunchFuncToGpuRuntimeCallsPass() {
-  return std::make_unique<GpuLaunchFuncToGpuRuntimeCallsPass>();
+mlir::createConvertGpuLaunchFuncToGpuRuntimeCallsPass(
+    StringRef gpuBinaryAnnotation) {
+  return std::make_unique<GpuLaunchFuncToGpuRuntimeCallsPass>(
+      gpuBinaryAnnotation);
 }

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index b8d2a6e05594..52756b6aae7f 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -13,8 +13,9 @@ set(MLIR_DIALECT_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTOR
 set(MLIR_RUNNER_UTILS_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 
 # Passed to lit.site.cfg.py.in to set up the path where to find the libraries
-# for the mlir cuda runner tests.
+# for the mlir cuda / rocm / vulkan runner tests.
 set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+set(MLIR_ROCM_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 
 configure_lit_site_cfg(
@@ -64,6 +65,12 @@ if(MLIR_CUDA_RUNNER_ENABLED)
   )
 endif()
 
+if(MLIR_ROCM_RUNNER_ENABLED)
+  list(APPEND MLIR_TEST_DEPENDS
+    mlir-rocm-runner
+  )
+endif()
+
 if(MLIR_VULKAN_RUNNER_ENABLED)
   list(APPEND MLIR_TEST_DEPENDS
     mlir-vulkan-runner

diff  --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index e78c82815b15..7e8778fc4e83 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -70,6 +70,7 @@
     ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore'),
     ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
     ToolSubst('%mlir_runner_utils_dir', config.mlir_runner_utils_dir, unresolved='ignore'),
+    ToolSubst('%rocm_wrapper_library_dir', config.rocm_wrapper_library_dir, unresolved='ignore'),
     ToolSubst('%vulkan_wrapper_library_dir', config.vulkan_wrapper_library_dir, unresolved='ignore')
 ])
 

diff  --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index e07acf4d21a8..b75518611cf2 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -39,6 +39,8 @@ config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
 config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@"
 config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@
 config.run_rocm_tests = @MLIR_ROCM_CONVERSIONS_ENABLED@
+config.rocm_wrapper_library_dir = "@MLIR_ROCM_WRAPPER_LIBRARY_DIR@"
+config.enable_rocm_runner = @MLIR_ROCM_RUNNER_ENABLED@
 config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@"
 config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@
 

diff  --git a/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
new file mode 100644
index 000000000000..433fd859dfb3
--- /dev/null
+++ b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-rocm-runner %s --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// Stores %arg0 into each element of %arg1 from a single-block launch with one
+// thread per element.
+func @other_func(%arg0 : f32, %arg1 : memref<?xf32>) {
+  %cst = constant 1 : index
+  %cst2 = dim %arg1, 0 : memref<?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
+             threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst) {
+    store %arg0, %arg1[%tx] : memref<?xf32>
+    gpu.terminator
+  }
+  return
+}
+
+// CHECK: [1, 1, 1, 1, 1]
+func @main() {
+  %host = alloc() : memref<5xf32>
+  %host_dyn = memref_cast %host : memref<5xf32> to memref<?xf32>
+  %host_unranked = memref_cast %host_dyn : memref<?xf32> to memref<*xf32>
+  // mgpuMemHostRegisterFloat also fills the buffer with 1.23 before the kernel
+  // overwrites every element with 1.0.
+  call @mgpuMemHostRegisterFloat(%host_unranked) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%host_unranked) : (memref<*xf32>) -> ()
+  %one = constant 1.0 : f32
+  %device = call @mgpuMemGetDeviceMemRef1dFloat(%host_dyn) : (memref<?xf32>) -> (memref<?xf32>)
+  call @other_func(%one, %device) : (f32, memref<?xf32>) -> ()
+  call @print_memref_f32(%host_unranked) : (memref<*xf32>) -> ()
+  return
+}
+
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-rocm-runner/lit.local.cfg b/mlir/test/mlir-rocm-runner/lit.local.cfg
new file mode 100644
index 000000000000..0ced06979486
--- /dev/null
+++ b/mlir/test/mlir-rocm-runner/lit.local.cfg
@@ -0,0 +1,2 @@
+# These tests invoke mlir-rocm-runner, which is only built when
+# MLIR_ROCM_RUNNER_ENABLED is set at configure time (exported to lit as
+# config.enable_rocm_runner); mark the directory unsupported otherwise.
+if not config.enable_rocm_runner:
+  config.unsupported = True

diff  --git a/mlir/test/mlir-rocm-runner/two-modules.mlir b/mlir/test/mlir-rocm-runner/two-modules.mlir
new file mode 100644
index 000000000000..598ac8110775
--- /dev/null
+++ b/mlir/test/mlir-rocm-runner/two-modules.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-rocm-runner %s --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// Exercises two gpu.launch ops in one function. Both store the thread index
+// into the same registered buffer, so the final contents are 0..12.
+// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+func @main() {
+  %arg = alloc() : memref<13xi32>
+  %dst = memref_cast %arg : memref<13xi32> to memref<?xi32>
+  %one = constant 1 : index
+  %sx = dim %dst, 0 : memref<?xi32>
+  %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
+  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+  %dst_device = call @mgpuMemGetDeviceMemRef1dInt32(%dst) : (memref<?xi32>) -> (memref<?xi32>)
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
+             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
+    %t0 = index_cast %tx : index to i32
+    store %t0, %dst_device[%tx] : memref<?xi32>
+    gpu.terminator
+  }
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
+             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
+    %t0 = index_cast %tx : index to i32
+    store %t0, %dst_device[%tx] : memref<?xi32>
+    gpu.terminator
+  }
+  call @print_memref_i32(%cast_dst) : (memref<*xi32>) -> ()
+  return
+}
+
+func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
+func @mgpuMemGetDeviceMemRef1dInt32(%ptr : memref<?xi32>) -> (memref<?xi32>)
+func @print_memref_i32(%ptr : memref<*xi32>)

diff  --git a/mlir/test/mlir-rocm-runner/vecadd.mlir b/mlir/test/mlir-rocm-runner/vecadd.mlir
new file mode 100644
index 000000000000..57195e2d51b3
--- /dev/null
+++ b/mlir/test/mlir-rocm-runner/vecadd.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-rocm-runner %s --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// Element-wise %arg2 = %arg0 + %arg1 using one block with one thread per
+// element.
+func @vecadd(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>, %arg2 : memref<?xf32>) {
+  %c1 = constant 1 : index
+  %len = dim %arg0, 0 : memref<?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %len, %block_y = %c1, %block_z = %c1) {
+    %lhs = load %arg0[%tx] : memref<?xf32>
+    %rhs = load %arg1[%tx] : memref<?xf32>
+    %sum = addf %lhs, %rhs : f32
+    store %sum, %arg2[%tx] : memref<?xf32>
+    gpu.terminator
+  }
+  return
+}
+
+// mgpuMemHostRegisterFloat fills every registered buffer with 1.23, so each
+// output element is 1.23 + 1.23.
+// CHECK: [2.46, 2.46, 2.46, 2.46, 2.46]
+func @main() {
+  %a = alloc() : memref<5xf32>
+  %b = alloc() : memref<5xf32>
+  %c = alloc() : memref<5xf32>
+  %a_dyn = memref_cast %a : memref<5xf32> to memref<?xf32>
+  %b_dyn = memref_cast %b : memref<5xf32> to memref<?xf32>
+  %c_dyn = memref_cast %c : memref<5xf32> to memref<?xf32>
+  %a_unranked = memref_cast %a_dyn : memref<?xf32> to memref<*xf32>
+  %b_unranked = memref_cast %b_dyn : memref<?xf32> to memref<*xf32>
+  %c_unranked = memref_cast %c_dyn : memref<?xf32> to memref<*xf32>
+  call @mgpuMemHostRegisterFloat(%a_unranked) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%b_unranked) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%c_unranked) : (memref<*xf32>) -> ()
+  %a_dev = call @mgpuMemGetDeviceMemRef1dFloat(%a_dyn) : (memref<?xf32>) -> (memref<?xf32>)
+  %b_dev = call @mgpuMemGetDeviceMemRef1dFloat(%b_dyn) : (memref<?xf32>) -> (memref<?xf32>)
+  %c_dev = call @mgpuMemGetDeviceMemRef1dFloat(%c_dyn) : (memref<?xf32>) -> (memref<?xf32>)
+
+  call @vecadd(%a_dev, %b_dev, %c_dev) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
+  call @print_memref_f32(%c_unranked) : (memref<*xf32>) -> ()
+  return
+}
+
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index f01648bffec3..e8f61633c92b 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(mlir-cuda-runner)
 add_subdirectory(mlir-cpu-runner)
 add_subdirectory(mlir-linalg-ods-gen)
 add_subdirectory(mlir-opt)
+add_subdirectory(mlir-rocm-runner)
 add_subdirectory(mlir-translate)
 add_subdirectory(mlir-vulkan-runner)
 add_subdirectory(mlir-shlib)

diff  --git a/mlir/tools/mlir-rocm-runner/CMakeLists.txt b/mlir/tools/mlir-rocm-runner/CMakeLists.txt
new file mode 100644
index 000000000000..7187fb7fc205
--- /dev/null
+++ b/mlir/tools/mlir-rocm-runner/CMakeLists.txt
@@ -0,0 +1,113 @@
+set(LLVM_OPTIONAL_SOURCES
+  rocm-runtime-wrappers.cpp
+  mlir-rocm-runner.cpp
+  )
+
+if(MLIR_ROCM_RUNNER_ENABLED)
+  if (NOT ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD))
+    message(SEND_ERROR
+      "Building the mlir rocm runner requires the AMDGPU backend")
+  endif()
+
+  # Ensure lld is enabled.
+  if (NOT "lld" IN_LIST LLVM_ENABLE_PROJECTS)
+    message(SEND_ERROR "lld is not enabled. Please revise LLVM_ENABLE_PROJECTS")
+  endif()
+
+  # lld header files.
+  include_directories(${MLIR_SOURCE_DIR}/../lld/include)
+
+  # Configure ROCm support.
+  if (NOT DEFINED ROCM_PATH)
+    if (NOT DEFINED ENV{ROCM_PATH})
+      set(ROCM_PATH "/opt/rocm" CACHE PATH "Path to which ROCm has been installed")
+    else()
+      set(ROCM_PATH $ENV{ROCM_PATH} CACHE PATH "Path to which ROCm has been installed")
+    endif()
+  endif()
+  # Derive HIP_PATH regardless of how ROCM_PATH was obtained; previously a
+  # user-supplied -DROCM_PATH=... left HIP_PATH unset, so the HIP cmake module
+  # and headers were searched for under "/cmake" and "/include".
+  if (NOT DEFINED HIP_PATH)
+    set(HIP_PATH "${ROCM_PATH}/hip" CACHE PATH " Path to which HIP has been installed")
+  endif()
+  set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH})
+  find_package(HIP)
+  if (NOT HIP_FOUND)
+    message(SEND_ERROR "Build the mlir rocm runner requires a working ROCm and HIP install")
+  else()
+    message(STATUS "ROCm HIP version: ${HIP_VERSION}")
+  endif()
+
+  # Locate HIP runtime library.
+  find_library(ROCM_RUNTIME_LIBRARY hip_hcc
+               PATHS "${HIP_PATH}/lib")
+  if (NOT ROCM_RUNTIME_LIBRARY)
+    message(SEND_ERROR "Could not locate ROCm HIP runtime library")
+  else()
+    message(STATUS "ROCm HIP runtime lib: ${ROCM_RUNTIME_LIBRARY}")
+  endif()
+
+  # Set HIP compile-time flags.
+  add_definitions(-D__HIP_PLATFORM_HCC__)
+
+  add_llvm_library(rocm-runtime-wrappers SHARED
+    rocm-runtime-wrappers.cpp
+  )
+  # Only real directories belong here; LLVMSupport is a link target and its
+  # usage requirements already arrive through target_link_libraries below.
+  target_include_directories(rocm-runtime-wrappers
+    PRIVATE
+    "${HIP_PATH}/../include"
+    "${HIP_PATH}/include"
+  )
+  target_link_libraries(rocm-runtime-wrappers
+    PUBLIC
+    LLVMSupport
+    ${ROCM_RUNTIME_LIBRARY}
+  )
+
+  get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+  get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+  set(LIBS
+    ${dialect_libs}
+    ${conversion_libs}
+    lldCommon
+    lldDriver
+    lldELF
+    LLVMCore
+    LLVMLTO
+    LLVMMC
+    LLVMMCParser
+    LLVMOption
+    LLVMSupport
+    MLIRJitRunner
+    MLIRAnalysis
+    MLIREDSC
+    MLIRExecutionEngine
+    MLIRIR
+    MLIRParser
+    MLIRROCDLIR
+    MLIRSupport
+    MLIRTargetLLVMIR
+    MLIRTargetROCDLIR
+    MLIRTransforms
+    MLIRTranslation
+    ${ROCM_RUNTIME_LIBRARY}
+  )
+
+  # Manually expand the target library, since our MLIR libraries
+  # aren't plugged into the LLVM dependency tracking. If we don't
+  # do this then we can't insert the CodeGen library after ourselves
+  llvm_expand_pseudo_components(TARGET_LIBS AllTargetsCodeGens AllTargetsAsmParsers)
+  # Prepend LLVM in front of every target, this is how the library
+  # are named with CMake
+  SET(targets_to_link)
+  FOREACH(t ${TARGET_LIBS})
+    LIST(APPEND targets_to_link "LLVM${t}")
+  ENDFOREACH(t)
+
+  add_llvm_tool(mlir-rocm-runner
+    mlir-rocm-runner.cpp
+
+    DEPENDS
+    rocm-runtime-wrappers
+    )
+  llvm_update_compile_flags(mlir-rocm-runner)
+  target_link_libraries(mlir-rocm-runner PRIVATE ${LIBS} ${targets_to_link})
+
+endif()

diff  --git a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
new file mode 100644
index 000000000000..19bb376b18d0
--- /dev/null
+++ b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
@@ -0,0 +1,248 @@
+//===- mlir-rocm-runner.cpp - MLIR ROCM Execution Driver-------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a command line utility that executes an MLIR file on the GPU by
+// translating MLIR to ROCDL/LLVM IR before JIT-compiling and executing the
+// latter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/STLExtras.h"
+
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/ExecutionEngine/JitRunner.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/ROCDLIR.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+
+// MC headers.
+#include "llvm/MC/MCAsmBackend.h"
+#include "llvm/MC/MCAsmInfo.h"
+#include "llvm/MC/MCCodeEmitter.h"
+#include "llvm/MC/MCContext.h"
+#include "llvm/MC/MCInstPrinter.h"
+#include "llvm/MC/MCInstrInfo.h"
+#include "llvm/MC/MCObjectFileInfo.h"
+#include "llvm/MC/MCObjectWriter.h"
+#include "llvm/MC/MCParser/AsmLexer.h"
+#include "llvm/MC/MCParser/MCTargetAsmParser.h"
+#include "llvm/MC/MCRegisterInfo.h"
+#include "llvm/MC/MCStreamer.h"
+#include "llvm/MC/MCSubtargetInfo.h"
+#include "llvm/MC/MCTargetOptionsCommandFlags.h"
+
+// lld headers.
+#include "lld/Common/Driver.h"
+
+using namespace mlir;
+using namespace llvm;
+
+/// Growable in-memory byte buffer used for assembled objects and code objects.
+using Blob = SmallVector<char, 0>;
+
+/// Target triple handed to the LLVM MC layer and the blob-generation pass.
+static cl::opt<std::string> tripleName("triple",
+                                       cl::init("amdgcn-amd-amdhsa"),
+                                       cl::value_desc("triple string"),
+                                       cl::desc("target triple"));
+
+/// GPU ISA version (processor name) code is generated and assembled for.
+// TODO(whchung): Add feature to automatically detect available AMD GCN ISA
+// version via `rocm-agent-enumerator` utility.
+static cl::opt<std::string> targetChip("target",
+                                       cl::init("gfx900"),
+                                       cl::value_desc("AMDGPU ISA version"),
+                                       cl::desc("target chip"));
+
+/// Subtarget feature string; the default disables code object v3.
+static cl::opt<std::string> features("feature",
+                                     cl::init("-code-object-v3"),
+                                     cl::value_desc("AMDGPU target features"),
+                                     cl::desc("target features"));
+
+/// Assembles textual AMD GCN ISA into a relocatable ELF object with LLVM MC.
+///
+/// \param isa    GCN assembly text to assemble.
+/// \param name   Prefix used for diagnostics printed to stderr.
+/// \param result Buffer that receives the bytes of the assembled object.
+/// \returns failure() if the target or any MC component cannot be created, or
+///          if the assembler reports an error; success() otherwise.
+static LogicalResult assembleIsa(const std::string &isa, StringRef name,
+                                 Blob &result) {
+  auto reportError = [&](StringRef message) -> LogicalResult {
+    WithColor::error(errs(), name) << message << "\n";
+    return failure();
+  };
+
+  raw_svector_ostream os(result);
+
+  std::string error;
+  Triple theTriple(Triple::normalize(tripleName));
+  const Target *theTarget =
+      TargetRegistry::lookupTarget(theTriple.normalize(), error);
+  if (!theTarget)
+    return reportError(error);
+
+  SourceMgr srcMgr;
+  srcMgr.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(isa), SMLoc());
+
+  const MCTargetOptions mcOptions;
+  std::unique_ptr<MCRegisterInfo> mri(theTarget->createMCRegInfo(tripleName));
+  if (!mri)
+    return reportError("unable to create target register info.");
+  std::unique_ptr<MCAsmInfo> mai(
+      theTarget->createMCAsmInfo(*mri, tripleName, mcOptions));
+  if (!mai)
+    return reportError("unable to create target asm info.");
+  mai->setRelaxELFRelocations(true);
+
+  MCObjectFileInfo mofi;
+  MCContext ctx(mai.get(), mri.get(), &mofi, &srcMgr, &mcOptions);
+  mofi.InitMCObjectFileInfo(theTriple, /*PIC=*/false, ctx,
+                            /*LargeCodeModel=*/false);
+
+  SmallString<128> cwd;
+  if (!sys::fs::current_path(cwd))
+    ctx.setCompilationDir(cwd);
+
+  std::unique_ptr<MCInstrInfo> mcii(theTarget->createMCInstrInfo());
+  std::unique_ptr<MCSubtargetInfo> sti(
+      theTarget->createMCSubtargetInfo(tripleName, targetChip, features));
+  if (!mcii || !sti)
+    return reportError("unable to create target instruction/subtarget info.");
+
+  // The code emitter, asm backend and object writer are moved into the object
+  // streamer, which owns them from then on. The writer is created before the
+  // backend is moved away.
+  std::unique_ptr<MCCodeEmitter> ce(
+      theTarget->createMCCodeEmitter(*mcii, *mri, ctx));
+  std::unique_ptr<MCAsmBackend> mab(
+      theTarget->createMCAsmBackend(*sti, *mri, mcOptions));
+  if (!ce || !mab)
+    return reportError("unable to create code emitter or asm backend.");
+  std::unique_ptr<MCObjectWriter> writer = mab->createObjectWriter(os);
+  // Declared after mcii/sti so it is destroyed before the objects it uses.
+  std::unique_ptr<MCStreamer> mcStreamer(theTarget->createMCObjectStreamer(
+      theTriple, ctx, std::move(mab), std::move(writer), std::move(ce), *sti,
+      mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
+      /*DWARFMustBeAtTheEnd=*/false));
+  if (!mcStreamer)
+    return reportError("unable to create object streamer.");
+  mcStreamer->setUseAssemblerInfoForParsing(true);
+
+  std::unique_ptr<MCAsmParser> parser(
+      createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
+  std::unique_ptr<MCTargetAsmParser> tap(
+      theTarget->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
+
+  if (!tap)
+    return reportError("assembler initialization error.");
+
+  parser->setTargetParser(*tap);
+  // MCAsmParser::Run returns true if any error was encountered; ignoring it
+  // would hand a truncated/empty object to the linker as if it were valid.
+  if (parser->Run(/*NoInitialDirectives=*/false))
+    return reportError("ISA assembly error.");
+
+  return success();
+}
+
+/// Links an assembled GCN object into an HSA code object (a shared ELF)
+/// loadable by the HIP runtime.
+///
+/// The object is written to a temporary file and linked with the in-process
+/// lld ELF driver into a second temporary file, whose contents are read back.
+/// Both temporary files are removed on return (FileRemover).
+///
+/// \param isaBlob   Relocatable object produced by assembleIsa().
+/// \param name      Prefix used for diagnostics printed to stderr.
+/// \param hsacoBlob Receives the bytes of the linked HSA code object.
+static LogicalResult createHsaco(const Blob &isaBlob, StringRef name,
+                                 Blob &hsacoBlob) {
+  // Save the ISA binary to a temp file.
+  int tempIsaBinaryFd = -1;
+  SmallString<128> tempIsaBinaryFilename;
+  std::error_code ec = sys::fs::createTemporaryFile(
+      "kernel", "o", tempIsaBinaryFd, tempIsaBinaryFilename);
+  if (ec) {
+    WithColor::error(errs(), name)
+        << "temporary file for ISA binary creation error.\n";
+    return failure();
+  }
+  FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
+  {
+    // The stream takes ownership of the descriptor and closes it.
+    raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, /*shouldClose=*/true);
+    tempIsaBinaryOs.write(isaBlob.data(), isaBlob.size());
+    tempIsaBinaryOs.close();
+    if (tempIsaBinaryOs.has_error()) {
+      // Clear the error so the destructor does not call report_fatal_error.
+      tempIsaBinaryOs.clear_error();
+      WithColor::error(errs(), name)
+          << "write ISA binary to temporary file error.\n";
+      return failure();
+    }
+  }
+
+  // Reserve a temp file name for the HSA code object. lld opens and writes the
+  // file itself, so use the overload that does not leave a descriptor open
+  // (the previous code leaked the descriptor returned here).
+  SmallString<128> tempHsacoFilename;
+  ec = sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFilename);
+  if (ec) {
+    WithColor::error(errs(), name)
+        << "temporary file for HSA code object creation error.\n";
+    return failure();
+  }
+  FileRemover cleanupHsaco(tempHsacoFilename);
+
+  // Invoke lld. Expect a true return value from lld.
+  bool ret = lld::elf::link({"ld.lld", "-shared", tempIsaBinaryFilename.c_str(),
+                             "-o", tempHsacoFilename.c_str()},
+                            /*canEarlyExit=*/false, llvm::outs(), llvm::errs());
+  if (!ret) {
+    WithColor::error(errs(), name) << "lld invocation error.\n";
+    return failure();
+  }
+
+  // Load the HSA code object.
+  auto hsacoFile = mlir::openInputFile(tempHsacoFilename);
+  if (!hsacoFile) {
+    WithColor::error(errs(), name)
+        << "read HSA code object from temp file error.\n";
+    return failure();
+  }
+  hsacoBlob.assign(hsacoFile->getBuffer().begin(),
+                   hsacoFile->getBuffer().end());
+
+  return success();
+}
+
+/// Blob-generation callback: translates a kernel module (already lowered to
+/// the ROCDL/LLVM dialects) into an LLVM IR module by delegating to
+/// translateModuleToROCDLIR and returning its result unchanged.
+static std::unique_ptr<llvm::Module> compileModuleToROCDLIR(Operation *m) {
+  // TODO(whchung): Link with ROCm-Device-Libs in case needed (ex: the Module
+  // depends on math functions).
+  return translateModuleToROCDLIR(m);
+}
+
+/// Blob-generation callback: turns GCN assembly text into an HSA code object.
+///
+/// The ISA is assembled with LLVM MC (assembleIsa) and the resulting object is
+/// linked with lld (createHsaco). On failure a diagnostic is printed and an
+/// empty (null) OwnedBlob is returned.
+///
+/// \param isa  GCN assembly text; taken by const reference so the potentially
+///             large ISA string is not copied for every kernel.
+/// \param loc  Unused.
+/// \param name Prefix used for diagnostics printed to stderr.
+static OwnedBlob compileISAToHsaco(const std::string &isa, Location loc,
+                                   StringRef name) {
+  Blob isaBlob;
+  Blob hsacoBlob;
+
+  // ISA -> relocatable object via MC, then object -> HSA code object via lld.
+  if (failed(assembleIsa(isa, name, isaBlob)) ||
+      failed(createHsaco(isaBlob, name, hsacoBlob))) {
+    WithColor::error(errs(), name) << "producing HSA code object error.\n";
+    return {};
+  }
+
+  return std::make_unique<std::vector<char>>(hsacoBlob.begin(),
+                                             hsacoBlob.end());
+}
+
+/// Lowering pipeline handed to the JIT runner:
+///   1. outline gpu.launch bodies into kernels inside gpu.module ops;
+///   2. within each gpu.module: strip debug info, lower GPU ops to ROCDL, and
+///      compile the kernel to a blob stored under "rocdl.hsaco";
+///   3. lower the host code to the LLVM dialect and rewrite kernel launches
+///      into runtime-wrapper calls that read the same annotation.
+static LogicalResult runMLIRPasses(ModuleOp m) {
+  PassManager pm(m.getContext());
+  applyPassManagerCLOptions(pm);
+
+  // Attribute name shared by the producer and consumer of the kernel blob.
+  const char *hsacoAnnotation = "rocdl.hsaco";
+
+  pm.addPass(createGpuKernelOutliningPass());
+
+  OpPassManager &kernelPm = pm.nest<gpu::GPUModuleOp>();
+  kernelPm.addPass(createStripDebugInfoPass());
+  kernelPm.addPass(createLowerGpuOpsToROCDLOpsPass());
+  kernelPm.addPass(createConvertGPUKernelToBlobPass(
+      compileModuleToROCDLIR, compileISAToHsaco, tripleName, targetChip,
+      features, /*gpuBinaryAnnotation=*/hsacoAnnotation));
+
+  pm.addPass(createLowerToLLVMPass());
+  pm.addPass(createConvertGpuLaunchFuncToGpuRuntimeCallsPass(
+      /*gpuBinaryAnnotation=*/hsacoAnnotation));
+
+  return pm.run(m);
+}
+
+/// Entry point: registers MLIR options and dialects, initializes the LLVM
+/// target components needed to generate and assemble AMDGPU code, then runs
+/// the generic MLIR JIT runner with the ROCm lowering pipeline.
+int main(int argc, char **argv) {
+  // MLIR-side registration.
+  registerPassManagerCLOptions();
+  mlir::registerAllDialects();
+  llvm::InitLLVM initLLVM(argc, argv);
+
+  // Generic LLVM target components (target info, MC layer, asm parsers).
+  llvm::InitializeAllTargetInfos();
+  llvm::InitializeAllTargetMCs();
+  llvm::InitializeAllAsmParsers();
+
+  // AMDGPU backend pieces, including the target and asm printer.
+  LLVMInitializeAMDGPUTarget();
+  LLVMInitializeAMDGPUTargetInfo();
+  LLVMInitializeAMDGPUTargetMC();
+  LLVMInitializeAMDGPUAsmPrinter();
+
+  mlir::initializeLLVMPasses();
+
+  auto *pipeline = &runMLIRPasses;
+  return mlir::JitRunnerMain(argc, argv, pipeline);
+}

diff  --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
new file mode 100644
index 000000000000..f49e6c91ea65
--- /dev/null
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -0,0 +1,143 @@
+//===- rocm-runtime-wrappers.cpp - MLIR ROCM runner wrapper library -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements C wrappers around the ROCM library for easy linking in ORC jit.
+// Also adds some debugging helpers that are helpful when writing MLIR code to
+// run on GPUs.
+//
+//===----------------------------------------------------------------------===//
+
+#include <cassert>
+#include <numeric>
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "hip/hip_runtime.h"
+
+namespace {
+/// Prints a diagnostic to stderr when a HIP API call did not succeed.
+///
+/// \param result Status returned by the HIP call.
+/// \param where  Short tag identifying the call site, included in the message.
+/// \returns The status converted to int32_t so wrappers can return it through
+///          their C ABI.
+int32_t reportErrorIfAny(hipError_t result, const char *where) {
+  if (result != hipSuccess) {
+    // Include the symbolic error name; the bare enum value is hard to decode.
+    llvm::errs() << "HIP failed with " << static_cast<int32_t>(result) << " ("
+                 << hipGetErrorString(result) << ") in " << where << "\n";
+  }
+  return static_cast<int32_t>(result);
+}
+} // anonymous namespace
+
+/// Loads a HIP module from the code object image at `data`
+/// (hipModuleLoadData) and stores the handle through `module`.
+/// \returns The HIP status code.
+extern "C" int32_t mgpuModuleLoad(void **module, void *data) {
+  auto *moduleOut = reinterpret_cast<hipModule_t *>(module);
+  return reportErrorIfAny(hipModuleLoadData(moduleOut, data), "ModuleLoad");
+}
+
+/// Looks up the kernel named `name` in a loaded HIP module and stores the
+/// function handle through `function`.
+/// \returns The HIP status code.
+extern "C" int32_t mgpuModuleGetFunction(void **function, void *module,
+                                         const char *name) {
+  auto *functionOut = reinterpret_cast<hipFunction_t *>(function);
+  auto hipModule = reinterpret_cast<hipModule_t>(module);
+  return reportErrorIfAny(hipModuleGetFunction(functionOut, hipModule, name),
+                          "GetFunction");
+}
+
+/// Launches `function` on `stream` with the given grid and block dimensions
+/// and `smem` bytes of dynamic shared memory.
+///
+/// The wrapper uses intptr_t instead of ROCM's unsigned int to match the type
+/// of MLIR's index type, which avoids casts in the generated MLIR code; the
+/// narrowing to unsigned int happens here instead.
+/// \returns The HIP status code.
+extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX,
+                                    intptr_t gridY, intptr_t gridZ,
+                                    intptr_t blockX, intptr_t blockY,
+                                    intptr_t blockZ, int32_t smem, void *stream,
+                                    void **params, void **extra) {
+  auto toDim = [](intptr_t extent) { return static_cast<unsigned int>(extent); };
+  hipError_t status = hipModuleLaunchKernel(
+      reinterpret_cast<hipFunction_t>(function), toDim(gridX), toDim(gridY),
+      toDim(gridZ), toDim(blockX), toDim(blockY), toDim(blockZ),
+      static_cast<unsigned int>(smem), reinterpret_cast<hipStream_t>(stream),
+      params, extra);
+  return reportErrorIfAny(status, "LaunchKernel");
+}
+
+/// Creates a new HIP stream and returns it as an opaque pointer.
+///
+/// Failures are reported to stderr. The handle starts out as nullptr, so a
+/// failed hipStreamCreate that leaves it untouched yields nullptr instead of
+/// an indeterminate (uninitialized) value.
+extern "C" void *mgpuGetStreamHelper() {
+  hipStream_t stream = nullptr;
+  reportErrorIfAny(hipStreamCreate(&stream), "StreamCreate");
+  return stream;
+}
+
+/// Synchronizes with `stream` via hipStreamSynchronize.
+/// \returns The HIP status code.
+extern "C" int32_t mgpuStreamSynchronize(void *stream) {
+  auto hipStream = reinterpret_cast<hipStream_t>(stream);
+  return reportErrorIfAny(hipStreamSynchronize(hipStream), "StreamSync");
+}
+
+/// Helper functions for writing mlir example code
+
+/// Registers `sizeBytes` bytes of host memory starting at `ptr` with the ROCm
+/// runtime (hipHostRegister, flags 0). Helpful until transfer functions are
+/// implemented. Errors are reported to stderr and otherwise ignored.
+extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
+  const unsigned int registerFlags = 0;
+  hipError_t status = hipHostRegister(ptr, sizeBytes, registerFlags);
+  reportErrorIfAny(status, "MemHostRegister");
+}
+
+/// Fills a densely packed host MemRef with `value` and registers its storage
+/// with the ROCm runtime. Helpful until transfer functions are implemented.
+///
+/// \param pointer First element of the MemRef (data + offset).
+/// \param sizes   Extent of each dimension; empty for a rank-0 MemRef, which
+///                holds exactly one element.
+/// \param strides Stride of each dimension; must describe a dense row-major
+///                layout (checked with assert only).
+/// \param value   Value written to every element before registration.
+template <typename T>
+void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
+                               llvm::ArrayRef<int64_t> strides, T value) {
+  assert(sizes.size() == strides.size());
+  llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+
+  // Suffix products: denseStrides[i] = sizes[i] * ... * sizes[rank - 1].
+  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
+                   std::multiplies<int64_t>());
+  // Guard rank 0: front()/rotate on an empty vector would be undefined.
+  int64_t count = denseStrides.empty() ? 1 : denseStrides.front();
+
+  // Only densely packed tensors are currently supported: shift the suffix
+  // products left by one to obtain the expected row-major strides.
+  if (!denseStrides.empty()) {
+    std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
+                denseStrides.end());
+    denseStrides.back() = 1;
+  }
+  assert(strides == llvm::makeArrayRef(denseStrides));
+
+  std::fill_n(pointer, count, value);
+  mgpuMemHostRegister(pointer, count * sizeof(T));
+}
+
+/// Registers an unranked float MemRef (passed as rank + descriptor pointer)
+/// with the ROCm runtime after filling it with 1.23f.
+extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
+  // The descriptor is viewed through the rank-1 struct; the strides are read
+  // `rank` elements past the start of `sizes`, which relies on strides
+  // immediately following sizes in the descriptor layout.
+  auto *descriptor = static_cast<StridedMemRefType<float, 1> *>(ptr);
+  int64_t *sizesBegin = descriptor->sizes;
+  llvm::ArrayRef<int64_t> sizes(sizesBegin, rank);
+  llvm::ArrayRef<int64_t> strides(sizesBegin + rank, rank);
+  float *firstElement = descriptor->data + descriptor->offset;
+  mgpuMemHostRegisterMemRef(firstElement, sizes, strides, 1.23f);
+}
+
+/// Registers an unranked int32 MemRef (passed as rank + descriptor pointer)
+/// with the ROCm runtime after filling it with 123.
+extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
+  // Same descriptor-layout assumption as the float variant: strides are read
+  // `rank` elements past the start of `sizes`.
+  auto *descriptor = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
+  int64_t *sizesBegin = descriptor->sizes;
+  llvm::ArrayRef<int64_t> sizes(sizesBegin, rank);
+  llvm::ArrayRef<int64_t> strides(sizesBegin + rank, rank);
+  int32_t *firstElement = descriptor->data + descriptor->offset;
+  mgpuMemHostRegisterMemRef(firstElement, sizes, strides, 123);
+}
+
+/// Selects device 0 and stores, through `devicePtr`, the device-side address
+/// that hipHostGetDevicePointer reports for the registered host buffer
+/// `hostPtr`. Errors are reported to stderr only; the callers in this file
+/// pre-initialize *devicePtr to nullptr.
+template <typename T>
+void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
+  reportErrorIfAny(hipSetDevice(0), "hipSetDevice");
+  // Named cast instead of a C-style cast: the HIP API takes a void** out-param.
+  reportErrorIfAny(hipHostGetDevicePointer(reinterpret_cast<void **>(devicePtr),
+                                           hostPtr, /*flags=*/0),
+                   "hipHostGetDevicePointer");
+}
+
+/// Builds a rank-1 float MemRef descriptor that addresses, from the device,
+/// the host buffer described by the expanded descriptor arguments.
+/// `allocated` is unused: the device pointer is derived from `aligned` and
+/// used for both pointer fields; offset, size and stride are forwarded as-is.
+extern "C" StridedMemRefType<float, 1>
+mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset,
+                              int64_t size, int64_t stride) {
+  float *deviceAddress = nullptr;
+  mgpuMemGetDevicePointer(aligned, &deviceAddress);
+  StridedMemRefType<float, 1> deviceMemRef = {
+      deviceAddress, deviceAddress, offset, {size}, {stride}};
+  return deviceMemRef;
+}
+
+/// Builds a rank-1 int32 MemRef descriptor that addresses, from the device,
+/// the host buffer described by the expanded descriptor arguments.
+/// `allocated` is unused: the device pointer is derived from `aligned` and
+/// used for both pointer fields; offset, size and stride are forwarded as-is.
+extern "C" StridedMemRefType<int32_t, 1>
+mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned,
+                              int64_t offset, int64_t size, int64_t stride) {
+  int32_t *deviceAddress = nullptr;
+  mgpuMemGetDevicePointer(aligned, &deviceAddress);
+  StridedMemRefType<int32_t, 1> deviceMemRef = {
+      deviceAddress, deviceAddress, offset, {size}, {stride}};
+  return deviceMemRef;
+}


        


More information about the Mlir-commits mailing list