[Mlir-commits] [mlir] 89808ce - [MLIR][mlir-spirv-cpu-runner] A SPIR-V cpu runner prototype

Lei Zhang llvmlistbot at llvm.org
Mon Oct 26 06:10:30 PDT 2020


Author: George Mitenkov
Date: 2020-10-26T09:09:29-04:00
New Revision: 89808ce7343b22586bfd0d3fafddcdbba94fbcbb

URL: https://github.com/llvm/llvm-project/commit/89808ce7343b22586bfd0d3fafddcdbba94fbcbb
DIFF: https://github.com/llvm/llvm-project/commit/89808ce7343b22586bfd0d3fafddcdbba94fbcbb.diff

LOG: [MLIR][mlir-spirv-cpu-runner] A SPIR-V cpu runner prototype

This patch introduces a SPIR-V runner. The aim is to run a GPU
kernel on a CPU via GPU -> SPIR-V -> LLVM conversions. This is a first
prototype, so more features will be added in due time.

- Overview
The runner follows a flow similar to that of the other in-tree runners.
However, after converting the kernel to SPIR-V, we encode the bind
attributes of the global variables that represent kernel arguments. The
SPIR-V module is then converted to LLVM. On the host side, we emulate
passing data to the device by creating globals in the main module with the
same symbolic names as those in the kernel module. These global variables
are later linked with the ones from the nested module. We copy data from
the kernel arguments to the globals, call the kernel function from the
nested module, and then copy the data back.
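
The copy-in / call / copy-back sequence described above can be sketched in
plain C++. This is only an illustration under stated assumptions, not the
code the runner generates: `Buffer1D`, `kernelArg0`, `kernelArg1`,
`kernelDouble` and `launchDouble` are hypothetical names, and a single pair
of globals stands in for the main-module and kernel-module globals after
they have been linked by symbolic name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical rank-1 buffer view, loosely modeled on a memref descriptor.
struct Buffer1D {
  std::int32_t *data;
  std::size_t size;
};

constexpr std::size_t kSize = 6;

// Hypothetical globals standing in for the global variables that represent
// the kernel arguments (one per argument).
static std::int32_t kernelArg0[kSize];
static std::int32_t kernelArg1[kSize];

// Stand-in for the lowered kernel: it takes no parameters and only reads
// and writes the globals.
static void kernelDouble() {
  for (std::size_t i = 0; i < kSize; ++i)
    kernelArg1[i] = kernelArg0[i] * 2;
}

// Host-side emulation of a kernel launch.
void launchDouble(Buffer1D input, Buffer1D output) {
  assert(input.size == kSize && output.size == kSize);
  // 1. Copy data from the kernel arguments to the globals.
  std::copy(input.data, input.data + kSize, kernelArg0);
  std::copy(output.data, output.data + kSize, kernelArg1);
  // 2. Call the kernel function.
  kernelDouble();
  // 3. Copy the data back from the globals to the arguments.
  std::copy(kernelArg0, kernelArg0 + kSize, input.data);
  std::copy(kernelArg1, kernelArg1 + kSize, output.data);
}
```

Calling `launchDouble` on an input buffer filled with 4 and an output buffer
filled with 0 leaves 8 in every output element, which mirrors what the
`double.mlir` test in this patch expects.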

- Current state
At the moment, the runner can run two modules, one nested in the other.
The kernel module must contain exactly one kernel function. Also, the
runner supports rank 1 integer memref types as arguments (to be extended).

- Enhancement of JitRunner and ExecutionEngine
To translate nested modules to LLVM IR, JitRunner and ExecutionEngine were
altered to take an optional function reference (defaulting to `nullptr`)
that acts as a custom LLVM IR module builder. This makes it possible to
customize how LLVM IR modules are created from MLIR modules.
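
The optional-builder pattern can be sketched without LLVM as follows. All
types here (`FakeMlirModule`, `FakeLlvmContext`, `FakeLlvmModule`) are
hypothetical stand-ins for `ModuleOp`, `llvm::LLVMContext` and
`llvm::Module`, and `std::function` is used in place of
`llvm::function_ref` so the example stays self-contained. Only the
fallback logic mirrors the patch: use the custom builder when one is given,
otherwise fall back to the default translation.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>

// Hypothetical stand-ins for the MLIR and LLVM types used by the patch.
struct FakeMlirModule { std::string name; };
struct FakeLlvmContext {};
struct FakeLlvmModule { std::string producedBy; };

// Plays the role of the TranslationCallback alias; std::function is an
// assumption made to avoid depending on llvm::function_ref.
using ModuleBuilder = std::function<std::unique_ptr<FakeLlvmModule>(
    const FakeMlirModule &, FakeLlvmContext &)>;

// Stand-in for the default translation path.
std::unique_ptr<FakeLlvmModule> defaultTranslate(const FakeMlirModule &m,
                                                 FakeLlvmContext &) {
  return std::unique_ptr<FakeLlvmModule>(
      new FakeLlvmModule{"default:" + m.name});
}

// Example of a custom builder that a runner could supply.
std::unique_ptr<FakeLlvmModule> customTranslate(const FakeMlirModule &m,
                                                FakeLlvmContext &) {
  return std::unique_ptr<FakeLlvmModule>(
      new FakeLlvmModule{"custom:" + m.name});
}

// Uses the custom builder if provided, the default translation otherwise.
std::string translate(const FakeMlirModule &m,
                      ModuleBuilder builder = nullptr) {
  FakeLlvmContext ctx;
  std::unique_ptr<FakeLlvmModule> result =
      builder ? builder(m, ctx) : defaultTranslate(m, ctx);
  // The patch reports a failed translation as an error; this sketch simply
  // asserts.
  assert(result && "could not convert to LLVM IR");
  return result->producedBy;
}
```

Existing callers that pass no builder keep the default behavior, while a
runner that needs special handling (such as linking a nested module) can
pass its own function.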

Reviewed By: ftynse, mravishankar

Differential Revision: https://reviews.llvm.org/D86108

Added: 
    mlir/test/mlir-spirv-cpu-runner/CMakeLists.txt
    mlir/test/mlir-spirv-cpu-runner/double.mlir
    mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
    mlir/test/mlir-spirv-cpu-runner/mlir_test_spirv_cpu_runner_c_wrappers.cpp
    mlir/test/mlir-spirv-cpu-runner/simple_add.mlir
    mlir/tools/mlir-spirv-cpu-runner/CMakeLists.txt
    mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp

Modified: 
    mlir/CMakeLists.txt
    mlir/examples/toy/Ch6/toyc.cpp
    mlir/examples/toy/Ch7/toyc.cpp
    mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
    mlir/include/mlir/ExecutionEngine/JitRunner.h
    mlir/lib/ExecutionEngine/ExecutionEngine.cpp
    mlir/lib/ExecutionEngine/JitRunner.cpp
    mlir/test/CMakeLists.txt
    mlir/test/lit.cfg.py
    mlir/test/lit.site.cfg.py.in
    mlir/tools/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 50511fd2aef9..7c26cc9f4b95 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -42,6 +42,7 @@ add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ROCM_CONVERSIONS_ENABLED}
 
 set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
 set(MLIR_ROCM_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir ROCm runner")
+set(MLIR_SPIRV_CPU_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir SPIR-V cpu runner")
 set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner")
 
 option(MLIR_INCLUDE_TESTS

diff  --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
index 2051089a18d3..d597a1f987b0 100644
--- a/mlir/examples/toy/Ch6/toyc.cpp
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -226,7 +226,8 @@ int runJit(mlir::ModuleOp module) {
 
   // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
   // the module.
-  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
+  auto maybeEngine = mlir::ExecutionEngine::create(
+      module, /*llvmModuleBuilder=*/nullptr, optPipeline);
   assert(maybeEngine && "failed to construct an execution engine");
   auto &engine = maybeEngine.get();
 

diff  --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
index 2eb32a7290f7..c28e2a8424dc 100644
--- a/mlir/examples/toy/Ch7/toyc.cpp
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -227,7 +227,8 @@ int runJit(mlir::ModuleOp module) {
 
   // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
   // the module.
-  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
+  auto maybeEngine = mlir::ExecutionEngine::create(
+      module, /*llvmModuleBuilder=*/nullptr, optPipeline);
   assert(maybeEngine && "failed to construct an execution engine");
   auto &engine = maybeEngine.get();
 

diff  --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
index d0ad8326bac8..1b6b0a8f670c 100644
--- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -65,6 +65,10 @@ class ExecutionEngine {
 
   /// Creates an execution engine for the given module.
   ///
+  /// If `llvmModuleBuilder` is provided, it will be used to create LLVM module
+  /// from the given MLIR module. Otherwise, a default `translateModuleToLLVMIR`
+  /// function will be used to translate MLIR module to LLVM IR.
+  ///
   /// If `transformer` is provided, it will be called on the LLVM module during
   /// JIT-compilation and can be used, e.g., for reporting or optimization.
   ///
@@ -84,6 +88,9 @@ class ExecutionEngine {
   /// the llvm's global Perf notification listener.
   static llvm::Expected<std::unique_ptr<ExecutionEngine>>
   create(ModuleOp m,
+         llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+                                                          llvm::LLVMContext &)>
+             llvmModuleBuilder = nullptr,
          llvm::function_ref<llvm::Error(llvm::Module *)> transformer = {},
          Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel = llvm::None,
          ArrayRef<StringRef> sharedLibPaths = {}, bool enableObjectCache = true,

diff  --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h
index 9e18166edc0b..2b7518c8cde2 100644
--- a/mlir/include/mlir/ExecutionEngine/JitRunner.h
+++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h
@@ -18,20 +18,29 @@
 #ifndef MLIR_SUPPORT_JITRUNNER_H_
 #define MLIR_SUPPORT_JITRUNNER_H_
 
+#include "mlir/IR/Module.h"
+
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/Module.h"
 
 namespace mlir {
 
+using TranslationCallback = llvm::function_ref<std::unique_ptr<llvm::Module>(
+    ModuleOp, llvm::LLVMContext &)>;
+
 class ModuleOp;
 struct LogicalResult;
 
 // Entry point for all CPU runners. Expects the common argc/argv arguments for
-// standard C++ main functions and an mlirTransformer.
-// The latter is applied after parsing the input into MLIR IR and before passing
-// the MLIR module to the ExecutionEngine.
+// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
+/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
+/// passing the MLIR module to the ExecutionEngine.
+/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
+/// It processes MLIR module and creates LLVM IR module.
 int JitRunnerMain(
     int argc, char **argv,
-    llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer);
+    llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
+    TranslationCallback llvmModuleBuilder = nullptr);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index bad433306502..9a9336e69f63 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -214,7 +214,11 @@ ExecutionEngine::ExecutionEngine(bool enableObjectCache,
                        : nullptr) {}
 
 Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
-    ModuleOp m, llvm::function_ref<Error(llvm::Module *)> transformer,
+    ModuleOp m,
+    llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+                                                     llvm::LLVMContext &)>
+        llvmModuleBuilder,
+    llvm::function_ref<Error(llvm::Module *)> transformer,
     Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel,
     ArrayRef<StringRef> sharedLibPaths, bool enableObjectCache,
     bool enableGDBNotificationListener, bool enablePerfNotificationListener) {
@@ -223,7 +227,8 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
       enablePerfNotificationListener);
 
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
-  auto llvmModule = translateModuleToLLVMIR(m, *ctx);
+  auto llvmModule = llvmModuleBuilder ? llvmModuleBuilder(m, *ctx)
+                                      : translateModuleToLLVMIR(m, *ctx);
   if (!llvmModule)
     return make_string_error("could not convert to LLVM IR");
   // FIXME: the triple should be passed to the translation or dialect conversion

diff  --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 3727306f19f4..c1bb1e68a4ca 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -20,7 +20,6 @@
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/Parser.h"
@@ -31,7 +30,6 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassNameParser.h"
-#include "llvm/IR/Module.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/SourceMgr.h"
@@ -134,7 +132,8 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
 
 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
 static Error
-compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint,
+compileAndExecute(Options &options, ModuleOp module,
+                  TranslationCallback llvmModuleBuilder, StringRef entryPoint,
                   std::function<llvm::Error(llvm::Module *)> transformer,
                   void **args) {
   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
@@ -143,8 +142,8 @@ compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint,
         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
   SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
                                  options.clSharedLibs.end());
-  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer,
-                                                      jitCodeGenOptLevel, libs);
+  auto expectedEngine = mlir::ExecutionEngine::create(
+      module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
   if (!expectedEngine)
     return expectedEngine.takeError();
 
@@ -165,13 +164,15 @@ compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint,
 }
 
 static Error compileAndExecuteVoidFunction(
-    Options &options, ModuleOp module, StringRef entryPoint,
+    Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
+    StringRef entryPoint,
     std::function<llvm::Error(llvm::Module *)> transformer) {
   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
   if (!mainFunction || mainFunction.empty())
     return make_string_error("entry point not found");
   void *empty = nullptr;
-  return compileAndExecute(options, module, entryPoint, transformer, &empty);
+  return compileAndExecute(options, module, llvmModuleBuilder, entryPoint,
+                           transformer, &empty);
 }
 
 template <typename Type>
@@ -196,7 +197,8 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
 }
 template <typename Type>
 Error compileAndExecuteSingleReturnFunction(
-    Options &options, ModuleOp module, StringRef entryPoint,
+    Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
+    StringRef entryPoint,
     std::function<llvm::Error(llvm::Module *)> transformer) {
   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
   if (!mainFunction || mainFunction.isExternal())
@@ -213,8 +215,8 @@ Error compileAndExecuteSingleReturnFunction(
     void *data;
   } data;
   data.data = &res;
-  if (auto error = compileAndExecute(options, module, entryPoint, transformer,
-                                     (void **)&data))
+  if (auto error = compileAndExecute(options, module, llvmModuleBuilder,
+                                     entryPoint, transformer, (void **)&data))
     return error;
 
   // Intentional printing of the output so we can test.
@@ -223,13 +225,16 @@ Error compileAndExecuteSingleReturnFunction(
   return Error::success();
 }
 
-/// Entry point for all CPU runners. Expects the common argc/argv
-/// arguments for standard C++ main functions and an mlirTransformer.
-/// The latter is applied after parsing the input into MLIR IR and
-/// before passing the MLIR module to the ExecutionEngine.
+/// Entry point for all CPU runners. Expects the common argc/argv arguments for
+/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
+/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
+/// passing the MLIR module to the ExecutionEngine.
+/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
+/// It processes MLIR module and creates LLVM IR module.
 int mlir::JitRunnerMain(
     int argc, char **argv,
-    function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
+    function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
+    TranslationCallback llvmModuleBuilder) {
   // Create the options struct containing the command line options for the
   // runner. This must come before the command line options are parsed.
   Options options;
@@ -289,7 +294,7 @@ int mlir::JitRunnerMain(
 
   // Get the function used to compile and execute the module.
   using CompileAndExecuteFnT =
-      Error (*)(Options &, ModuleOp, StringRef,
+      Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
                 std::function<llvm::Error(llvm::Module *)>);
   auto compileAndExecuteFn =
       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
@@ -301,7 +306,7 @@ int mlir::JitRunnerMain(
 
   Error error =
       compileAndExecuteFn
-          ? compileAndExecuteFn(options, m.get(),
+          ? compileAndExecuteFn(options, m.get(), llvmModuleBuilder,
                                 options.mainFuncName.getValue(), transformer)
           : make_string_error("unsupported function type");
 

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index cc4a9b1af10d..b9bc6bfb960a 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -20,9 +20,10 @@ set(MLIR_DIALECT_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTOR
 set(MLIR_RUNNER_UTILS_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 
 # Passed to lit.site.cfg.py.in to set up the path where to find the libraries
-# for the mlir cuda / rocm / vulkan runner tests.
+# for the mlir cuda / rocm / spirv / vulkan runner tests.
 set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 set(MLIR_ROCM_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+set(MLIR_SPIRV_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 
 configure_lit_site_cfg(
@@ -81,6 +82,14 @@ if(MLIR_ROCM_RUNNER_ENABLED)
   )
 endif()
 
+if(MLIR_SPIRV_CPU_RUNNER_ENABLED)
+  add_subdirectory(mlir-spirv-cpu-runner)
+  list(APPEND MLIR_TEST_DEPENDS
+    mlir-spirv-cpu-runner
+    mlir_test_spirv_cpu_runner_c_wrappers
+  )
+endif()
+
 if(MLIR_VULKAN_RUNNER_ENABLED)
   list(APPEND MLIR_TEST_DEPENDS
     mlir-vulkan-runner

diff  --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 67ca6692d10a..2a4bf85f9770 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -74,6 +74,7 @@
     ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
     ToolSubst('%mlir_runner_utils_dir', config.mlir_runner_utils_dir, unresolved='ignore'),
     ToolSubst('%rocm_wrapper_library_dir', config.rocm_wrapper_library_dir, unresolved='ignore'),
+    ToolSubst('%spirv_wrapper_library_dir', config.spirv_wrapper_library_dir, unresolved='ignore'),
     ToolSubst('%vulkan_wrapper_library_dir', config.vulkan_wrapper_library_dir, unresolved='ignore'),
 ])
 

diff  --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index 1dc823239b08..cdb79616ecc0 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -41,6 +41,8 @@ config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@
 config.run_rocm_tests = @MLIR_ROCM_CONVERSIONS_ENABLED@
 config.rocm_wrapper_library_dir = "@MLIR_ROCM_WRAPPER_LIBRARY_DIR@"
 config.enable_rocm_runner = @MLIR_ROCM_RUNNER_ENABLED@
+config.spirv_wrapper_library_dir = "@MLIR_SPIRV_WRAPPER_LIBRARY_DIR@"
+config.enable_spirv_cpu_runner = @MLIR_SPIRV_CPU_RUNNER_ENABLED@
 config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@"
 config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@
 config.enable_bindings_python = @MLIR_BINDINGS_PYTHON_ENABLED@

diff  --git a/mlir/test/mlir-spirv-cpu-runner/CMakeLists.txt b/mlir/test/mlir-spirv-cpu-runner/CMakeLists.txt
new file mode 100644
index 000000000000..c7521db57a38
--- /dev/null
+++ b/mlir/test/mlir-spirv-cpu-runner/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_OPTIONAL_SOURCES
+  mlir_test_spirv_cpu_runner_c_wrappers.cpp
+  )
+
+add_llvm_library(mlir_test_spirv_cpu_runner_c_wrappers SHARED mlir_test_spirv_cpu_runner_c_wrappers.cpp)
+target_compile_definitions(mlir_test_spirv_cpu_runner_c_wrappers PRIVATE mlir_test_spirv_cpu_runner_c_wrappers_EXPORTS)

diff  --git a/mlir/test/mlir-spirv-cpu-runner/double.mlir b/mlir/test/mlir-spirv-cpu-runner/double.mlir
new file mode 100644
index 000000000000..8251375edc4f
--- /dev/null
+++ b/mlir/test/mlir-spirv-cpu-runner/double.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-spirv-cpu-runner %s -e main --entry-point-result=void --shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%spirv_wrapper_library_dir/libmlir_test_spirv_cpu_runner_c_wrappers%shlibext
+
+// CHECK: [8,  8,  8,  8,  8,  8]
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_variable_pointers]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    gpu.func @double(%arg0 : memref<6xi32>, %arg1 : memref<6xi32>)
+      kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %factor = constant 2 : i32
+
+      %i0 = constant 0 : index
+      %i1 = constant 1 : index
+      %i2 = constant 2 : index
+      %i3 = constant 3 : index
+      %i4 = constant 4 : index
+      %i5 = constant 5 : index
+
+      %x0 = load %arg0[%i0] : memref<6xi32>
+      %x1 = load %arg0[%i1] : memref<6xi32>
+      %x2 = load %arg0[%i2] : memref<6xi32>
+      %x3 = load %arg0[%i3] : memref<6xi32>
+      %x4 = load %arg0[%i4] : memref<6xi32>
+      %x5 = load %arg0[%i5] : memref<6xi32>
+
+      %y0 = muli %x0, %factor : i32
+      %y1 = muli %x1, %factor : i32
+      %y2 = muli %x2, %factor : i32
+      %y3 = muli %x3, %factor : i32
+      %y4 = muli %x4, %factor : i32
+      %y5 = muli %x5, %factor : i32
+
+      store %y0, %arg1[%i0] : memref<6xi32>
+      store %y1, %arg1[%i1] : memref<6xi32>
+      store %y2, %arg1[%i2] : memref<6xi32>
+      store %y3, %arg1[%i3] : memref<6xi32>
+      store %y4, %arg1[%i4] : memref<6xi32>
+      store %y5, %arg1[%i5] : memref<6xi32>
+      gpu.return
+    }
+  }
+  func @main() {
+    %input = alloc() : memref<6xi32>
+    %output = alloc() : memref<6xi32>
+    %four = constant 4 : i32
+    %zero = constant 0 : i32
+    %input_casted = memref_cast %input : memref<6xi32> to memref<?xi32>
+    %output_casted = memref_cast %output : memref<6xi32> to memref<?xi32>
+    call @fillI32Buffer(%input_casted, %four) : (memref<?xi32>, i32) -> ()
+    call @fillI32Buffer(%output_casted, %zero) : (memref<?xi32>, i32) -> ()
+
+    %one = constant 1 : index
+    "gpu.launch_func"(%one, %one, %one,
+                      %one, %one, %one,
+                      %input, %output) { kernel = @kernels::@double }
+        : (index, index, index, index, index, index, memref<6xi32>, memref<6xi32>) -> ()
+    %result = memref_cast %output : memref<6xi32> to memref<*xi32>
+    call @print_memref_i32(%result) : (memref<*xi32>) -> ()
+    return
+  }
+
+  func @fillI32Buffer(%arg0 : memref<?xi32>, %arg1 : i32)
+  func @print_memref_i32(%ptr : memref<*xi32>)
+}

diff  --git a/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
new file mode 100644
index 000000000000..f218c0b26c9a
--- /dev/null
+++ b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
@@ -0,0 +1,8 @@
+import sys
+
+# FIXME: llvm orc does not support the COFF rtld.
+if sys.platform == 'win32':
+    config.unsupported = True
+
+if not config.enable_spirv_cpu_runner:
+  config.unsupported = True

diff  --git a/mlir/test/mlir-spirv-cpu-runner/mlir_test_spirv_cpu_runner_c_wrappers.cpp b/mlir/test/mlir-spirv-cpu-runner/mlir_test_spirv_cpu_runner_c_wrappers.cpp
new file mode 100644
index 000000000000..82179a6dc770
--- /dev/null
+++ b/mlir/test/mlir-spirv-cpu-runner/mlir_test_spirv_cpu_runner_c_wrappers.cpp
@@ -0,0 +1,38 @@
+//===- mlir_test_spirv_cpu_runner_c_wrappers.cpp - Runner testing library -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A small library for SPIR-V cpu runner testing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/RunnerUtils.h"
+
+extern "C" void
+_mlir_ciface_fillI32Buffer(StridedMemRefType<int32_t, 1> *mem_ref,
+                           int32_t value) {
+  std::fill_n(mem_ref->basePtr, mem_ref->sizes[0], value);
+}
+
+extern "C" void
+_mlir_ciface_fillF32Buffer1D(StridedMemRefType<float, 1> *mem_ref,
+                             float value) {
+  std::fill_n(mem_ref->basePtr, mem_ref->sizes[0], value);
+}
+
+extern "C" void
+_mlir_ciface_fillF32Buffer2D(StridedMemRefType<float, 2> *mem_ref,
+                             float value) {
+  std::fill_n(mem_ref->basePtr, mem_ref->sizes[0] * mem_ref->sizes[1], value);
+}
+
+extern "C" void
+_mlir_ciface_fillF32Buffer3D(StridedMemRefType<float, 3> *mem_ref,
+                             float value) {
+  std::fill_n(mem_ref->basePtr,
+              mem_ref->sizes[0] * mem_ref->sizes[1] * mem_ref->sizes[2], value);
+}

diff  --git a/mlir/test/mlir-spirv-cpu-runner/simple_add.mlir b/mlir/test/mlir-spirv-cpu-runner/simple_add.mlir
new file mode 100644
index 000000000000..476c459423d6
--- /dev/null
+++ b/mlir/test/mlir-spirv-cpu-runner/simple_add.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-spirv-cpu-runner %s -e main --entry-point-result=void --shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%spirv_wrapper_library_dir/libmlir_test_spirv_cpu_runner_c_wrappers%shlibext
+
+// CHECK: [[[7.7,    0,    0], [7.7,    0,    0], [7.7,    0,    0]], [[0,    7.7,    0], [0,    7.7,    0], [0,    7.7,    0]], [[0,    0,    7.7], [0,    0,    7.7], [0,    0,    7.7]]]
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_8bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    gpu.func @sum(%arg0 : memref<3xf32>, %arg1 : memref<3x3xf32>, %arg2 :  memref<3x3x3xf32>)
+      kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %i0 = constant 0 : index
+      %i1 = constant 1 : index
+      %i2 = constant 2 : index
+
+      %x = load %arg0[%i0] : memref<3xf32>
+      %y = load %arg1[%i0, %i0] : memref<3x3xf32>
+      %sum = addf %x, %y : f32
+
+      store %sum, %arg2[%i0, %i0, %i0] : memref<3x3x3xf32>
+      store %sum, %arg2[%i0, %i1, %i0] : memref<3x3x3xf32>
+      store %sum, %arg2[%i0, %i2, %i0] : memref<3x3x3xf32>
+      store %sum, %arg2[%i1, %i0, %i1] : memref<3x3x3xf32>
+      store %sum, %arg2[%i1, %i1, %i1] : memref<3x3x3xf32>
+      store %sum, %arg2[%i1, %i2, %i1] : memref<3x3x3xf32>
+      store %sum, %arg2[%i2, %i0, %i2] : memref<3x3x3xf32>
+      store %sum, %arg2[%i2, %i1, %i2] : memref<3x3x3xf32>
+      store %sum, %arg2[%i2, %i2, %i2] : memref<3x3x3xf32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %input1 = alloc() : memref<3xf32>
+    %input2 = alloc() : memref<3x3xf32>
+    %output = alloc() : memref<3x3x3xf32>
+    %0 = constant 0.0 : f32
+    %3 = constant 3.4 : f32
+    %4 = constant 4.3 : f32
+    %input1_casted = memref_cast %input1 : memref<3xf32> to memref<?xf32>
+    %input2_casted = memref_cast %input2 : memref<3x3xf32> to memref<?x?xf32>
+    %output_casted = memref_cast %output : memref<3x3x3xf32> to memref<?x?x?xf32>
+    call @fillF32Buffer1D(%input1_casted, %3) : (memref<?xf32>, f32) -> ()
+    call @fillF32Buffer2D(%input2_casted, %4) : (memref<?x?xf32>, f32) -> ()
+    call @fillF32Buffer3D(%output_casted, %0) : (memref<?x?x?xf32>, f32) -> ()
+
+    %one = constant 1 : index
+    "gpu.launch_func"(%one, %one, %one,
+                      %one, %one, %one,
+                      %input1, %input2, %output) { kernel = @kernels::@sum }
+        : (index, index, index, index, index, index, memref<3xf32>, memref<3x3xf32>, memref<3x3x3xf32>) -> ()
+    %result = memref_cast %output : memref<3x3x3xf32> to memref<*xf32>
+    call @print_memref_f32(%result) : (memref<*xf32>) -> ()
+    return
+  }
+  func @fillF32Buffer1D(%arg0 : memref<?xf32>, %arg1 : f32)
+  func @fillF32Buffer2D(%arg0 : memref<?x?xf32>, %arg1 : f32)
+  func @fillF32Buffer3D(%arg0 : memref<?x?x?xf32>, %arg1 : f32)
+  func @print_memref_f32(%arg0 : memref<*xf32>)
+}

diff  --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 23a2fcbc14eb..ab59514ef6a7 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -5,5 +5,6 @@ add_subdirectory(mlir-opt)
 add_subdirectory(mlir-reduce)
 add_subdirectory(mlir-rocm-runner)
 add_subdirectory(mlir-shlib)
+add_subdirectory(mlir-spirv-cpu-runner)
 add_subdirectory(mlir-translate)
 add_subdirectory(mlir-vulkan-runner)
\ No newline at end of file

diff  --git a/mlir/tools/mlir-spirv-cpu-runner/CMakeLists.txt b/mlir/tools/mlir-spirv-cpu-runner/CMakeLists.txt
new file mode 100644
index 000000000000..69080ae66dce
--- /dev/null
+++ b/mlir/tools/mlir-spirv-cpu-runner/CMakeLists.txt
@@ -0,0 +1,32 @@
+set(LLVM_OPTIONAL_SOURCES
+  mlir-spirv-cpu-runner.cpp
+  )
+
+if (MLIR_SPIRV_CPU_RUNNER_ENABLED)
+  message(STATUS "Building SPIR-V cpu runner")
+
+  add_llvm_tool(mlir-spirv-cpu-runner
+    mlir-spirv-cpu-runner.cpp
+  )
+
+  llvm_update_compile_flags(mlir-spirv-cpu-runner)
+
+  get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+  get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+
+  target_link_libraries(mlir-spirv-cpu-runner PRIVATE
+    ${conversion_libs}
+    ${dialect_libs}
+    MLIRAnalysis
+    MLIREDSC
+    MLIRExecutionEngine
+    MLIRIR
+    MLIRJitRunner
+    MLIRLLVMIR
+    MLIRParser
+    MLIRTargetLLVMIR
+    MLIRTransforms
+    MLIRTranslation
+    MLIRSupport
+  )
+endif()

diff  --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
new file mode 100644
index 000000000000..9979801023b2
--- /dev/null
+++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
@@ -0,0 +1,90 @@
+//===- mlir-spirv-cpu-runner.cpp - MLIR SPIR-V Execution on CPU -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry point to a command line utility that executes an MLIR file on the
+// CPU by translating MLIR GPU module and host part to LLVM IR before
+// JIT-compiling and executing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
+#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/ExecutionEngine/JitRunner.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
+
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Linker/Linker.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/TargetSelect.h"
+
+using namespace mlir;
+
+/// A utility function that builds llvm::Module from two nested MLIR modules.
+///
+/// module @main {
+///   module @kernel {
+///     // Some ops
+///   }
+///   // Some other ops
+/// }
+///
+/// Each of these two modules is translated to LLVM IR module, then they are
+/// linked together and returned.
+static std::unique_ptr<llvm::Module>
+convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) {
+  // Verify that there is only one nested module.
+  auto modules = module.getOps<ModuleOp>();
+  if (!llvm::hasSingleElement(modules)) {
+    module.emitError("The module must contain exactly one nested module");
+    return nullptr;
+  }
+
+  // Translate nested module and erase it.
+  ModuleOp nested = *modules.begin();
+  std::unique_ptr<llvm::Module> kernelModule =
+      translateModuleToLLVMIR(nested, context);
+  nested.erase();
+
+  std::unique_ptr<llvm::Module> mainModule =
+      translateModuleToLLVMIR(module, context);
+  llvm::Linker::linkModules(*mainModule, std::move(kernelModule));
+  return mainModule;
+}
+
+static LogicalResult runMLIRPasses(ModuleOp module) {
+  PassManager passManager(module.getContext());
+  applyPassManagerCLOptions(passManager);
+  passManager.addPass(createGpuKernelOutliningPass());
+  passManager.addPass(createConvertGPUToSPIRVPass());
+
+  OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
+  nestedPM.addPass(spirv::createLowerABIAttributesPass());
+  nestedPM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
+  passManager.addPass(createLowerHostCodeToLLVMPass());
+  passManager.addPass(createConvertSPIRVToLLVMPass());
+  return passManager.run(module);
+}
+
+int main(int argc, char **argv) {
+  llvm::InitLLVM y(argc, argv);
+
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+  mlir::initializeLLVMPasses();
+
+  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule);
+}


        

