[Mlir-commits] [mlir] cae4067 - [MLIR][mlir-spirv-cpu-runner] A pass to emulate a call to kernel in LLVM

Lei Zhang llvmlistbot at llvm.org
Mon Oct 26 05:14:46 PDT 2020


Author: George Mitenkov
Date: 2020-10-26T08:11:04-04:00
New Revision: cae4067ec1cdf7846aa46dab13d3bc1f58b76016

URL: https://github.com/llvm/llvm-project/commit/cae4067ec1cdf7846aa46dab13d3bc1f58b76016
DIFF: https://github.com/llvm/llvm-project/commit/cae4067ec1cdf7846aa46dab13d3bc1f58b76016.diff

LOG: [MLIR][mlir-spirv-cpu-runner] A pass to emulate a call to kernel in LLVM

This patch introduces LowerHostCodeToLLVMPass, a pass needed for running
`mlir-spirv-cpu-runner`.

This pass emulates a `gpu.launch_func` call in LLVM dialect and lowers
the host module code to LLVM. It removes the `gpu.module` and creates a
sequence of global variables that are later linked to the variables
in the kernel module, as well as a series of copies to and from
them that emulate the memory transfer from the host to the device and
back. It also converts the remaining Standard dialect into
LLVM dialect, emitting C wrappers.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86112

Added: 
    mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
    mlir/test/Conversion/SPIRVToLLVM/lower-host-to-llvm-calls.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h
    mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2458dca8b49a..eb073fc8f26c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -113,6 +113,12 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
   ];
 }
 
+// Host-side lowering used by `mlir-spirv-cpu-runner`: each `gpu.launch_func`
+// is emulated by copying its memref operands into LLVM globals that stand in
+// for the kernel's bound variables, calling the kernel as a plain LLVM
+// function, and copying the data back. The rest of the host module is lowered
+// from Standard to LLVM dialect.
+def LowerHostCodeToLLVM : Pass<"lower-host-to-llvm", "ModuleOp"> {
+  let summary = "Lowers the host module code and `gpu.launch_func` to LLVM";
+  let constructor = "mlir::createLowerHostCodeToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUToNVVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h
index 9cfce928ff6e..9525a1eac984 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h
@@ -20,6 +20,16 @@ class ModuleOp;
 template <typename T>
 class OperationPass;
 
+/// Creates a pass to emulate a `gpu.launch_func` call in LLVM dialect and
+/// lower the host module code to LLVM.
+///
+/// This transformation creates a sequence of global variables that are later
+/// linked to the variables in the kernel module, and a series of copies
+/// to/from them to emulate the memory transfer to the device and back to the
+/// host. It also converts the remaining Standard dialect into LLVM dialect,
+/// emitting C wrappers.
+std::unique_ptr<OperationPass<ModuleOp>> createLowerHostCodeToLLVMPass();
+
 /// Creates a pass to convert SPIR-V operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertSPIRVToLLVMPass();
 

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt
index adf835c25a23..5c18928ad37a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_conversion_library(MLIRSPIRVToLLVM
+  ConvertLaunchFuncToLLVMCalls.cpp
   ConvertSPIRVToLLVM.cpp
   ConvertSPIRVToLLVMPass.cpp
 
@@ -10,6 +11,7 @@ add_mlir_conversion_library(MLIRSPIRVToLLVM
   intrinsics_gen
 
   LINK_LIBS PUBLIC
+  MLIRGPU
   MLIRSPIRV
   MLIRLLVMIR
   MLIRStandardToLLVM

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
new file mode 100644
index 000000000000..a850c9badc8d
--- /dev/null
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -0,0 +1,307 @@
+//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements passes to convert `gpu.launch_func` op into a sequence
+// of LLVM calls that emulate the host and device sides.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
+#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+/// Prefix carried by the name of the `spv.module` that corresponds to a
+/// `gpu.module` (GPU to SPIR-V conversion names it
+/// `__spv__{kernel_module_name}`, e.g. `@foo` -> `@__spv__foo`).
+static constexpr const char kSPIRVModule[] = "__spv__";
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Returns the attribute name under which a global variable stores its
+/// descriptor set number: the `DescriptorSet` decoration spelled in snake_case.
+static std::string descriptorSetName() {
+  StringRef camelCase = stringifyDecoration(spirv::Decoration::DescriptorSet);
+  return llvm::convertToSnakeFromCamelCase(camelCase);
+}
+
+/// Returns the attribute name under which a global variable stores its
+/// binding number: the `Binding` decoration spelled in snake_case.
+static std::string bindingName() {
+  StringRef camelCase = stringifyDecoration(spirv::Decoration::Binding);
+  return llvm::convertToSnakeFromCamelCase(camelCase);
+}
+
+/// Maps a bound global variable back to the position of the kernel operand it
+/// represents. `LowerABIAttributesPass` assigns kernel operand `i` to
+/// (descriptor set 0, binding i), so the binding number alone is the index.
+/// The binding attribute is dereferenced unconditionally, so the caller must
+/// ensure `op` carries one.
+static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
+  return op.getAttrOfType<IntegerAttr>(bindingName()).getInt();
+}
+
+/// Emits an `llvm.intr.memcpy` that copies `size` bytes from `src` to `dst`
+/// at the builder's current insertion point.
+static void copy(Location loc, Value dst, Value src, Value size,
+                 OpBuilder &builder) {
+  // The memcpy intrinsic takes an explicit i1 `isVolatile` operand; these
+  // transfers are ordinary, non-volatile copies.
+  auto i1Ty = LLVM::LLVMType::getInt1Ty(builder.getContext());
+  auto falseAttr = builder.getBoolAttr(false);
+  Value notVolatile = builder.create<LLVM::ConstantOp>(loc, i1Ty, falseAttr);
+  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, notVolatile);
+}
+
+/// Encodes the binding and descriptor set numbers into a new symbolic name.
+/// The name is specified by
+///   {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
+/// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
+/// binding numbers. Despite its name, this function only builds the name; it
+/// does not create a variable. `op` must carry both decorations (see
+/// `hasDescriptorSetAndBinding`).
+static std::string
+createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
+                                 StringRef kernelModuleName) {
+  IntegerAttr descriptorSet =
+      op.getAttrOfType<IntegerAttr>(descriptorSetName());
+  IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName());
+  assert(descriptorSet && binding &&
+         "expected a global variable with a descriptor set and a binding");
+  // formatv renders StringRefs and integers directly (integers in plain
+  // decimal), so no intermediate std::string copies are needed.
+  return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
+                       kernelModuleName, op.sym_name(), descriptorSet.getInt(),
+                       binding.getInt())
+      .str();
+}
+
+/// Tells whether `op` is decorated with both a descriptor set number and a
+/// binding number, i.e. whether it can be treated as a kernel argument.
+static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
+  if (!op.getAttrOfType<IntegerAttr>(descriptorSetName()))
+    return false;
+  return static_cast<bool>(op.getAttrOfType<IntegerAttr>(bindingName()));
+}
+
+/// Collects into `globalVariableMap` the SPIR-V global variables of `module`
+/// that stand for kernel arguments, keyed by the kernel operand index derived
+/// from their binding number. Only a module with exactly one entry point is
+/// accepted, because then every `spv.globalVariable` carrying both a
+/// descriptor set and a binding belongs to that single kernel.
+/// NOTE: if two such variables share a binding number, the one visited later
+/// replaces the earlier map entry.
+static LogicalResult getKernelGlobalVariables(
+    spirv::ModuleOp module,
+    DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
+  if (!llvm::hasSingleElement(module.getOps<spirv::EntryPointOp>()))
+    return module.emitError(
+        "The module must contain exactly one entry point function");
+
+  for (spirv::GlobalVariableOp variable :
+       module.getOps<spirv::GlobalVariableOp>()) {
+    if (!hasDescriptorSetAndBinding(variable))
+      continue;
+    unsigned index = calculateGlobalIndex(variable);
+    globalVariableMap[index] = variable;
+  }
+  return success();
+}
+
+/// Encodes the SPIR-V module's symbolic name into the name of its entry point
+/// function, renaming it (and all of its uses inside the module) to
+///   {spv_module_name}_{function_name}
+/// so that it matches the declaration emitted on the host side.
+///
+/// Fails with a diagnostic if the module has no symbol name, does not contain
+/// exactly one entry point, or the entry point refers to a missing function.
+static LogicalResult encodeKernelName(spirv::ModuleOp module) {
+  // `spv.module`'s symbol name is optional; unwrapping an empty Optional
+  // would assert, and without a name there is nothing to encode.
+  Optional<StringRef> spvModuleName = module.sym_name();
+  if (!spvModuleName)
+    return module.emitError("SPIR-V module must have a symbol name");
+
+  // This is called on every SPIR-V module, including ones no
+  // `gpu.launch_func` refers to and which therefore never went through
+  // `getKernelGlobalVariables()`. Re-check the single-entry-point invariant
+  // before dereferencing `begin()`.
+  auto entryPoints = module.getOps<spirv::EntryPointOp>();
+  if (!llvm::hasSingleElement(entryPoints))
+    return module.emitError(
+        "The module must contain exactly one entry point function");
+
+  spirv::EntryPointOp entryPoint = *entryPoints.begin();
+  StringRef funcName = entryPoint.fn();
+  auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName);
+  if (!funcOp)
+    return entryPoint.emitError("entry point function '")
+           << funcName << "' is not found";
+
+  std::string newFuncName = spvModuleName->str() + "_" + funcName.str();
+  if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
+    return failure();
+  SymbolTable::setSymbolName(funcOp, newFuncName);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// One emulated host<->device transfer, remembered so that it can be replayed
+/// in the opposite direction after the kernel call. `dst` is the address of
+/// the LLVM global standing in for the kernel variable, `src` is the memref's
+/// allocated pointer, and `size` is the number of bytes transferred.
+struct CopyInfo {
+  Value dst, src, size;
+};
+
+/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
+/// copy the data to the global variable (emulating device side), call the
+/// kernel as a normal void LLVM function, and copy the data back (emulating the
+/// host side).
+class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
+  using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    gpu::LaunchFuncOp launchOp = cast<gpu::LaunchFuncOp>(op);
+    MLIRContext *context = rewriter.getContext();
+    auto module = launchOp.getParentOfType<ModuleOp>();
+
+    // Get the SPIR-V module that represents the gpu kernel module. The module
+    // is named:
+    //   __spv__{kernel_module_name}
+    // based on GPU to SPIR-V conversion.
+    StringRef kernelModuleName = launchOp.getKernelModuleName();
+    std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
+    auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName);
+    if (!spvModule) {
+      return launchOp.emitOpError("SPIR-V kernel module '")
+             << spvModuleName << "' is not found";
+    }
+
+    // Get all global variables associated with the kernel operands. This only
+    // inspects the kernel module, so do it before creating any IR.
+    DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
+    if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
+      return failure();
+
+    // Declare kernel function in the main module so that it later can be linked
+    // with its definition from the kernel module. We know that the kernel
+    // function would have no arguments and the data is passed via global
+    // variables. The name of the kernel will be
+    //   {spv_module_name}_{kernel_function_name}
+    // to avoid symbolic name conflicts.
+    StringRef kernelFuncName = launchOp.getKernelName();
+    std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
+    auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName);
+    if (!kernelFunc) {
+      // The guard restores the insertion point (before `launchOp`) on exit.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(module.getBody());
+      kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
+          rewriter.getUnknownLoc(), newKernelFuncName,
+          LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context),
+                                        ArrayRef<LLVM::LLVMType>(),
+                                        /*isVarArg=*/false));
+    }
+
+    // Traverse kernel operands that were converted to MemRefDescriptors. For
+    // each operand, create a global variable and copy data from operand to it.
+    Location loc = launchOp.getLoc();
+    SmallVector<CopyInfo, 4> copyInfo;
+    auto numKernelOperands = launchOp.getNumKernelOperands();
+    auto kernelOperands = operands.take_back(numKernelOperands);
+    for (auto operand : llvm::enumerate(kernelOperands)) {
+      // Check if the kernel's operand is a ranked memref.
+      auto memRefType = launchOp.getKernelOperand(operand.index())
+                            .getType()
+                            .dyn_cast<MemRefType>();
+      if (!memRefType)
+        return failure();
+
+      // Get the global variable in the SPIR-V module that is associated with
+      // the kernel operand. Use `lookup` rather than `operator[]`: the latter
+      // would default-insert a null op for an unbound operand, which would
+      // then be dereferenced below.
+      spirv::GlobalVariableOp spirvGlobal =
+          globalVariableMap.lookup(operand.index());
+      if (!spirvGlobal) {
+        return launchOp.emitOpError("kernel operand #")
+               << operand.index()
+               << " is not bound to a global variable in SPIR-V module '"
+               << spvModuleName << "'";
+      }
+
+      // Calculate the size of the memref and get the pointer to the allocated
+      // buffer.
+      SmallVector<Value, 4> sizes;
+      getMemRefDescriptorSizes(loc, memRefType, operand.value(), rewriter,
+                               sizes);
+      Value size = getCumulativeSizeInBytes(loc, memRefType.getElementType(),
+                                            sizes, rewriter);
+      MemRefDescriptor descriptor(operand.value());
+      Value src = descriptor.allocatedPtr(rewriter, loc);
+
+      // Construct the new name of the SPIR-V global and create a
+      // corresponding LLVM dialect global variable.
+      auto pointeeType =
+          spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
+      auto dstGlobalType = typeConverter.convertType(pointeeType);
+      if (!dstGlobalType)
+        return failure();
+      std::string name =
+          createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
+      // Check if this variable has already been created.
+      auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
+      if (!dstGlobal) {
+        // The guard restores the insertion point (before `launchOp`) on exit.
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(module.getBody());
+        dstGlobal = rewriter.create<LLVM::GlobalOp>(
+            loc, dstGlobalType.cast<LLVM::LLVMType>(),
+            /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute());
+      }
+
+      // Copy the data from src operand pointer to dst global variable. Save
+      // src, dst and size so that we can copy data back after emulating the
+      // kernel call.
+      Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
+      copy(loc, dst, src, size, rewriter);
+
+      CopyInfo info;
+      info.dst = dst;
+      info.src = src;
+      info.size = size;
+      copyInfo.push_back(info);
+    }
+    // Create a call to the kernel and copy the data back.
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
+                                              ArrayRef<Value>());
+    for (const CopyInfo &info : copyInfo)
+      copy(loc, info.src, info.dst, info.size, rewriter);
+    return success();
+  }
+};
+
+/// Pass that lowers the host module to LLVM dialect. It erases the
+/// `gpu.module`s, lowers Standard to LLVM (emitting C wrappers), rewrites each
+/// `gpu.launch_func` with `GPULaunchLowering`, and finally renames the entry
+/// point of every SPIR-V module to match the host-side kernel declaration.
+class LowerHostCodeToLLVM
+    : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
+public:
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Erase the GPU module.
+    for (auto gpuModule :
+         llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
+      gpuModule.erase();
+
+    // Specify options to lower Standard to LLVM and pull in the conversion
+    // patterns.
+    LowerToLLVMOptions options = {
+        /*useBarePtrCallConv=*/false,
+        /*emitCWrappers=*/true,
+        /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
+    auto *context = module.getContext();
+    OwningRewritePatternList patterns;
+    LLVMTypeConverter typeConverter(context, options);
+    populateStdToLLVMConversionPatterns(typeConverter, patterns);
+    patterns.insert<GPULaunchLowering>(typeConverter);
+
+    // Pull in SPIR-V type conversion patterns to convert SPIR-V global
+    // variable's type to LLVM dialect type.
+    populateSPIRVToLLVMTypeConversion(typeConverter);
+
+    ConversionTarget target(*context);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    if (failed(applyPartialConversion(module, target, patterns))) {
+      // Do not go on to rename kernels in a module that failed to convert.
+      signalPassFailure();
+      return;
+    }
+
+    // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
+    // conflicts. A failed rename would leave a kernel out of sync with its
+    // host-side declaration, so it must fail the pass. Keep iterating so that
+    // every problematic module is processed.
+    for (auto spvModule : module.getOps<spirv::ModuleOp>())
+      if (failed(encodeKernelName(spvModule)))
+        signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates an instance of the `lower-host-to-llvm` pass, which lowers the host
+/// module code and `gpu.launch_func` to LLVM.
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+mlir::createLowerHostCodeToLLVMPass() {
+  return std::make_unique<LowerHostCodeToLLVM>();
+}

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/lower-host-to-llvm-calls.mlir b/mlir/test/Conversion/SPIRVToLLVM/lower-host-to-llvm-calls.mlir
new file mode 100644
index 000000000000..176d860e9549
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/lower-host-to-llvm-calls.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt --lower-host-to-llvm %s | FileCheck %s
+  
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_variable_pointers]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+
+  //       CHECK: llvm.mlir.global linkonce @__spv__foo_bar_arg_0_descriptor_set0_binding0() : !llvm.struct<(array<6 x i32>)>
+  //       CHECK: llvm.func @__spv__foo_bar()
+
+  //       CHECK: spv.module @__spv__foo
+  //       CHECK:   spv.globalVariable @bar_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4> [0])>, StorageBuffer>
+  //       CHECK:   spv.func @__spv__foo_bar
+  
+  //       CHECK:   spv.EntryPoint "GLCompute" @__spv__foo_bar
+  //       CHECK:   spv.ExecutionMode @__spv__foo_bar "LocalSize", 1, 1, 1
+
+  // CHECK-LABEL: @main
+  //       CHECK:   %[[SRC:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+  //  CHECK-NEXT:   %[[DEST:.*]] = llvm.mlir.addressof @__spv__foo_bar_arg_0_descriptor_set0_binding0 : !llvm.ptr<struct<(array<6 x i32>)>>
+  //  CHECK-NEXT:   llvm.mlir.constant(false) : !llvm.i1
+  //  CHECK-NEXT:   "llvm.intr.memcpy"(%[[DEST]], %[[SRC]], %[[SIZE:.*]], %{{.*}}) : (!llvm.ptr<struct<(array<6 x i32>)>>, !llvm.ptr<i32>, !llvm.i64, !llvm.i1) -> ()
+  //  CHECK-NEXT:   llvm.call @__spv__foo_bar() : () -> ()
+  //  CHECK-NEXT:   llvm.mlir.constant(false) : !llvm.i1
+  //  CHECK-NEXT:   "llvm.intr.memcpy"(%[[SRC]], %[[DEST]], %[[SIZE]], %{{.*}}) : (!llvm.ptr<i32>, !llvm.ptr<struct<(array<6 x i32>)>>, !llvm.i64, !llvm.i1) -> ()
+
+  spv.module @__spv__foo Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_variable_pointers]> {
+    spv.globalVariable @bar_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4> [0])>, StorageBuffer>
+    spv.func @bar() "None" attributes {workgroup_attributions = 0 : i64} {
+      %0 = spv._address_of @bar_arg_0 : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4> [0])>, StorageBuffer>
+      spv.Return
+    }
+    spv.EntryPoint "GLCompute" @bar
+    spv.ExecutionMode @bar "LocalSize", 1, 1, 1
+  }
+
+  gpu.module @foo {
+    gpu.func @bar(%arg0: memref<6xi32>) kernel attributes {spv.entry_point_abi = {local_size = dense<1> : vector<3xi32>}} {
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %buffer = alloc() : memref<6xi32>
+    %one = constant 1 : index
+    "gpu.launch_func"(%one, %one, %one,
+                      %one, %one, %one,
+                      %buffer) {kernel = @foo::@bar} : (index, index, index, index, index, index, memref<6xi32>) -> ()
+    return
+  }
+}


        


More information about the Mlir-commits mailing list