[Mlir-commits] [mlir] 8ae074b - [mlir][gpu] Add the Select Object compilation attribute.

Fabian Mora llvmlistbot at llvm.org
Fri Aug 11 15:00:42 PDT 2023


Author: Fabian Mora
Date: 2023-08-11T22:00:35Z
New Revision: 8ae074b19597e38c55273ebe368e05ae3a425214

URL: https://github.com/llvm/llvm-project/commit/8ae074b19597e38c55273ebe368e05ae3a425214
DIFF: https://github.com/llvm/llvm-project/commit/8ae074b19597e38c55273ebe368e05ae3a425214.diff

LOG: [mlir][gpu] Add the Select Object compilation attribute.

**For an explanation of these patches see D154153.**

Commit message:
This patch adds the default offloading handler for GPU binary ops, `#gpu.select_object`.
It selects the object to embed based on an index or a target attribute, embeds the
object as a global string, and launches the kernel using the same scheme as the
GPU to LLVM pass.

Depends on D154137

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D154147

Added: 
    mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
index 0802e9025a8181..2e2a084413fa57 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
@@ -39,4 +39,29 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
 def GPUObjectArrayAttr :
   TypedArrayAttrBase<GPU_ObjectAttr, "an array of GPU object attributes">;
 
+//===----------------------------------------------------------------------===//
+// GPU offloading LLVM translation handler attributes.
+//===----------------------------------------------------------------------===//
+
+def GPU_SelectObjectAttr : GPU_Attr<"SelectObject", "select_object", [
+      OffloadingTranslationAttrTrait
+    ]> {
+  let description = [{
+    This GPU offloading handler selects a single GPU object for embedding. The
+    object is selected based on the `target` parameter, which is either a
+    non-negative integer (the index of the object to select) or a GPU target
+    attribute (the object array is searched for an object with that target).
+
+    The first object in a `gpu.binary` operation is selected if no target is
+    specified.
+  }];
+  let parameters = (ins
+    OptionalParameter<"Attribute", "Target to select for embedding.">:$target
+  );
+  let assemblyFormat = [{
+    (`<` $target^ `>`)?
+  }];
+  let genVerifyDecl = 1;
+}
+
 #endif // GPU_COMPILATION_ATTRS

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 2f8cb968f18b00..5921df9fa5e8ce 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1093,15 +1093,19 @@ def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
      - An optional attribute implementing the offloading LLVM translation interface.
      - An array of GPU object attributes.
 
-    During translation into LLVM, the offloading attribute will be called
-    for translating GPU binary and launch operations into LLVM instructions. If
-    no attribute is provided, the default handler selects the first object from
-    the array and embeds it as a string.
+    During translation, the offloading attribute will be called for translating
+    GPU `binary` and `launch_func` operations. The default offloading handler is
+    `#gpu.select_object`, which selects the first object from the array and
+    embeds it as a string.
 
     Examples:
     ```
+      // Selects the first object.
       gpu.binary @myobject [#gpu.object<...>, #gpu.object<...>]
+      // Uses the `#foo.my_handler` for handling the binary during translation.
       gpu.binary @myobject <#foo.my_handler> [#gpu.object<...>, #gpu.object<...>]
+      // Selects the object with the `#rocdl.target` target attribute.
+      gpu.binary @myobject <#gpu.select_object<#rocdl.target>> [#gpu.object<...>, #gpu.object<#rocdl.target, ...>]
     ```
   }];
   let builders = [
@@ -1114,7 +1118,7 @@ def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
   ];
   let skipDefaultBuilders = 1;
   let assemblyFormat = [{
-    $sym_name (`<` $offloadingHandler ^ `>`)? attr-dict $objects
+    $sym_name custom<OffloadingHandler>($offloadingHandler) attr-dict $objects
   }];
 }
 

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h
index 3ba7a18bed8c6e..b42158d9b1e57c 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h
@@ -26,6 +26,13 @@ void registerGPUDialectTranslation(DialectRegistry &registry);
 /// associated with the given context.
 void registerGPUDialectTranslation(MLIRContext &context);
 
+namespace gpu {
+/// Adds to `registry` the extension that attaches the offloading LLVM
+/// translation external model to the `#gpu.select_object` attribute.
+void registerOffloadingLLVMTranslationInterfacesExternalModels(
+    DialectRegistry &registry);
+} // namespace gpu
+
 } // namespace mlir
 
 #endif // MLIR_TARGET_LLVMIR_DIALECT_GPU_GPUTOLLVMIRTRANSLATION_H

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 3fcc816d09a7c6..7c1d8b0e1abb73 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1652,7 +1652,10 @@ void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
   result.attributes.push_back(builder.getNamedAttr(
       SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
   properties.objects = objects;
-  properties.offloadingHandler = offloadingHandler;
+  if (offloadingHandler)
+    properties.offloadingHandler = offloadingHandler;
+  else
+    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
 }
 
 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
@@ -1661,6 +1664,25 @@ void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
         objects.size() > 0 ? builder.getArrayAttr(objects) : ArrayAttr());
 }
 
+/// Parses the optional `<handler>` of a `gpu.binary`. When no handler is
+/// written, `offloadingHandler` is set to `#gpu.select_object` (no target).
+static ParseResult parseOffloadingHandler(OpAsmParser &parser,
+                                          Attribute &offloadingHandler) {
+  // An explicit handler is spelled `<attr>`; bail out on any malformed piece.
+  if (succeeded(parser.parseOptionalLess()) &&
+      (parser.parseAttribute(offloadingHandler) || parser.parseGreater()))
+    return failure();
+  if (!offloadingHandler) {
+    Builder &attrBuilder = parser.getBuilder();
+    offloadingHandler = attrBuilder.getAttr<SelectObjectAttr>(nullptr);
+  }
+  return success();
+}
+
+/// Prints the offloading handler of a `gpu.binary` as `<handler>`. Nothing is
+/// printed when the handler is absent or is the default `#gpu.select_object`
+/// without a target; `parseOffloadingHandler` restores that default when the
+/// `<...>` clause is missing, so both cases round-trip.
+static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
+                                   Attribute offloadingHandler) {
+  // The handler is an optional attribute: a null value must be elided as well,
+  // printing it inside `<...>` would not produce parsable IR.
+  Attribute defaultHandler = SelectObjectAttr::get(op->getContext(), nullptr);
+  if (offloadingHandler && offloadingHandler != defaultHandler)
+    printer << '<' << offloadingHandler << '>';
+}
+
 //===----------------------------------------------------------------------===//
 // GPUMemcpyOp
 //===----------------------------------------------------------------------===//
@@ -1932,6 +1954,27 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyDimOfAllocOp>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// GPU select object attribute
+//===----------------------------------------------------------------------===//
+
+/// Verifies the optional `target` of `#gpu.select_object`: when present it must
+/// be either a non-negative object index or a GPU target attribute.
+LogicalResult
+gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                              Attribute target) {
+  // A missing target selects the first object and is always valid.
+  if (!target)
+    return success();
+  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
+    // Index 0 is valid, so only negative indices are rejected.
+    if (intAttr.getInt() < 0)
+      return emitError() << "The object index must be non-negative.";
+    return success();
+  }
+  if (!::mlir::isa<TargetAttrInterface>(target))
+    return emitError()
+           << "The target attribute must be a GPU Target attribute.";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU target options
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
index a0811228b38419..11816ff5c2c1f1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_translation_library(MLIRGPUToLLVMIRTranslation
   GPUToLLVMIRTranslation.cpp
+  SelectObjectAttr.cpp
 
   LINK_COMPONENTS
   Core

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
index a12316112e095d..3e677bcc5d7f6b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
@@ -36,6 +36,7 @@ void mlir::registerGPUDialectTranslation(DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
     dialect->addInterfaces<GPUDialectLLVMIRTranslationInterface>();
   });
+  gpu::registerOffloadingLLVMTranslationInterfacesExternalModels(registry);
 }
 
 void mlir::registerGPUDialectTranslation(MLIRContext &context) {

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
new file mode 100644
index 00000000000000..dd19fafa6ecb46
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -0,0 +1,371 @@
+//===- SelectObjectAttr.cpp - Implements the SelectObject attribute -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the `OffloadingLLVMTranslationAttrInterface` for the
+// `SelectObject` attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+namespace {
+/// Model of `OffloadingLLVMTranslationAttrInterface` attached to
+/// `#gpu.select_object`.
+class SelectObjectAttrImpl
+    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
+          SelectObjectAttrImpl> {
+public:
+  /// Embeds the object selected from the `gpu.binary` `operation` into the
+  /// host LLVM module as a global binary string.
+  LogicalResult embedBinary(Attribute attribute, Operation *operation,
+                            llvm::IRBuilderBase &builder,
+                            LLVM::ModuleTranslation &moduleTranslation) const;
+
+  /// Lowers the `gpu.launch_func` `launchFuncOperation` into the sequence of
+  /// LLVM instructions that performs the kernel launch call.
+  LogicalResult launchKernel(Attribute attribute,
+                             Operation *launchFuncOperation,
+                             Operation *binaryOperation,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) const;
+};
+
+/// Name of the global string that holds the embedded binary `binaryName`.
+std::string getBinaryIdentifier(StringRef binaryName) {
+  std::string identifier = binaryName.str();
+  identifier += "_bin_cst";
+  return identifier;
+}
+} // namespace
+
+void mlir::gpu::registerOffloadingLLVMTranslationInterfacesExternalModels(
+    DialectRegistry &registry) {
+  // The model is attached through a GPU-dialect extension of the registry;
+  // the dialect instance itself is not needed.
+  registry.addExtension(+[](MLIRContext *context, gpu::GPUDialect *) {
+    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*context);
+  });
+}
+
+// Selects one object of the `gpu.binary` `operation` according to the
+// `#gpu.select_object` `attribute` and embeds it into the host LLVM module as
+// an internal constant global string named `getBinaryIdentifier(<op name>)`,
+// which is the name `launchKernel` later looks up.
+// Returns failure if `operation` is not a `gpu.binary` or if no object matches.
+LogicalResult SelectObjectAttrImpl::embedBinary(
+    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  assert(operation && "The binary operation must be non null.");
+  if (!operation)
+    return failure();
+
+  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
+  if (!op) {
+    operation->emitError("Operation must be a GPU binary.");
+    return failure();
+  }
+
+  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
+
+  // Obtain the index of the object to select: a null target selects the first
+  // object, an integer target is the index itself, and any other target
+  // selects the first object whose target attribute is equal to it.
+  int64_t index = -1;
+  if (Attribute target = cast<gpu::SelectObjectAttr>(attribute).getTarget()) {
+    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
+      index = indexAttr.getInt();
+    } else {
+      // `objects` is a GPU object array, so every element is an ObjectAttr.
+      for (auto [i, attr] : llvm::enumerate(objects)) {
+        if (mlir::cast<gpu::ObjectAttr>(attr).getTarget() == target) {
+          index = i;
+          break;
+        }
+      }
+    }
+  } else {
+    index = 0;
+  }
+
+  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
+    op->emitError("The requested target object couldn't be found.");
+    return failure();
+  }
+  auto object = mlir::cast<gpu::ObjectAttr>(objects[index]);
+
+  llvm::Module *module = moduleTranslation.getLLVMModule();
+
+  // Embed the object as a global string (no trailing NUL is added).
+  llvm::Constant *binary = llvm::ConstantDataArray::getString(
+      builder.getContext(), object.getObject().getValue(), false);
+  llvm::GlobalVariable *serializedObj =
+      new llvm::GlobalVariable(*module, binary->getType(), true,
+                               llvm::GlobalValue::LinkageTypes::InternalLinkage,
+                               binary, getBinaryIdentifier(op.getName()));
+  serializedObj->setAlignment(llvm::MaybeAlign(8));
+  serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
+  return success();
+}
+
+namespace llvm {
+namespace {
+// Emits the host-side LLVM IR for a `gpu.launch_func`: loads the GPU module,
+// launches the kernel and releases the resources it created, through the
+// `mgpu*` runtime wrapper functions declared on demand in `module`.
+class LaunchKernel {
+public:
+  LaunchKernel(Module &module, IRBuilderBase &builder,
+               mlir::LLVM::ModuleTranslation &moduleTranslation);
+  // Get the kernel launch callee (`mgpuLaunchKernel`).
+  FunctionCallee getKernelLaunchFn();
+
+  // Get the module function callee (`mgpuModuleGetFunction`).
+  FunctionCallee getModuleFunctionFn();
+
+  // Get the module load callee (`mgpuModuleLoad`).
+  FunctionCallee getModuleLoadFn();
+
+  // Get the module unload callee (`mgpuModuleUnload`).
+  FunctionCallee getModuleUnloadFn();
+
+  // Get the stream create callee (`mgpuStreamCreate`).
+  FunctionCallee getStreamCreateFn();
+
+  // Get the stream destroy callee (`mgpuStreamDestroy`).
+  FunctionCallee getStreamDestroyFn();
+
+  // Get the stream sync callee (`mgpuStreamSynchronize`).
+  FunctionCallee getStreamSyncFn();
+
+  // Get or create the global string holding the kernel name.
+  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
+
+  // Create the void* kernel array for passing the arguments.
+  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
+
+  // Create the full kernel launch.
+  mlir::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op);
+
+private:
+  // Host LLVM module receiving the runtime declarations and name globals.
+  Module &module;
+  // Builder used to emit the launch instructions.
+  IRBuilderBase &builder;
+  // Translation state, used to map MLIR values to their LLVM counterparts.
+  mlir::LLVM::ModuleTranslation &moduleTranslation;
+  // LLVM types used by the runtime function signatures, set in the ctor.
+  Type *i32Ty{};
+  Type *voidTy{};
+  Type *intPtrTy{};
+  PointerType *ptrTy{};
+};
+} // namespace
+} // namespace llvm
+
+// Lowers the `gpu.launch_func` `launchFuncOperation` into host LLVM IR via
+// `llvm::LaunchKernel`; fails if the operation is not a `gpu.launch_func`.
+LogicalResult SelectObjectAttrImpl::launchKernel(
+    Attribute attribute, Operation *launchFuncOperation,
+    Operation *binaryOperation, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  assert(launchFuncOperation && "The launch func operation must be non null.");
+  if (!launchFuncOperation)
+    return failure();
+
+  if (auto launchFuncOp =
+          mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation)) {
+    llvm::LaunchKernel launcher(*moduleTranslation.getLLVMModule(), builder,
+                                moduleTranslation);
+    return launcher.createKernelLaunch(launchFuncOp);
+  }
+
+  launchFuncOperation->emitError("Operation must be a GPU launch func Op.");
+  return failure();
+}
+
+// Caches the LLVM types shared by the runtime function signatures; the
+// pointer-sized integer type comes from `module`'s data layout.
+llvm::LaunchKernel::LaunchKernel(
+    Module &module, IRBuilderBase &builder,
+    mlir::LLVM::ModuleTranslation &moduleTranslation)
+    : module(module), builder(builder), moduleTranslation(moduleTranslation),
+      i32Ty(builder.getInt32Ty()), voidTy(builder.getVoidTy()),
+      intPtrTy(builder.getIntPtrTy(module.getDataLayout())),
+      ptrTy(builder.getPtrTy(0)) {}
+
+llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
+  // void mgpuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY,
+  //                       blockZ, sharedMemBytes, stream, params, extra)
+  Type *argTys[] = {ptrTy,    intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                    intPtrTy, i32Ty,    ptrTy,    ptrTy,    ptrTy};
+  FunctionType *fnTy = FunctionType::get(voidTy, argTys, /*isVarArg=*/false);
+  return module.getOrInsertFunction("mgpuLaunchKernel", fnTy);
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
+  // ptr mgpuModuleGetFunction(ptr module, ptr kernelName)
+  Type *argTys[] = {ptrTy, ptrTy};
+  FunctionType *fnTy = FunctionType::get(ptrTy, argTys, /*isVarArg=*/false);
+  return module.getOrInsertFunction("mgpuModuleGetFunction", fnTy);
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
+  // ptr mgpuModuleLoad(ptr binary)
+  Type *argTys[] = {ptrTy};
+  FunctionType *fnTy = FunctionType::get(ptrTy, argTys, /*isVarArg=*/false);
+  return module.getOrInsertFunction("mgpuModuleLoad", fnTy);
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
+  // void mgpuModuleUnload(ptr module)
+  Type *argTys[] = {ptrTy};
+  FunctionType *fnTy = FunctionType::get(voidTy, argTys, /*isVarArg=*/false);
+  return module.getOrInsertFunction("mgpuModuleUnload", fnTy);
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
+  // ptr mgpuStreamCreate()
+  FunctionType *fnTy = FunctionType::get(ptrTy, /*isVarArg=*/false);
+  return module.getOrInsertFunction("mgpuStreamCreate", fnTy);
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
+  // void mgpuStreamDestroy(ptr stream)
+  Type *argTys[] = {ptrTy};
+  FunctionType *fnTy = FunctionType::get(voidTy, argTys, /*isVarArg=*/false);
+  return module.getOrInsertFunction("mgpuStreamDestroy", fnTy);
+}
+
+llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
+  // void mgpuStreamSynchronize(ptr stream)
+  Type *argTys[] = {ptrTy};
+  FunctionType *fnTy = FunctionType::get(voidTy, argTys, /*isVarArg=*/false);
+  return module.getOrInsertFunction("mgpuStreamSynchronize", fnTy);
+}
+
+// Returns the LLVM global holding the name of kernel `kernelName` as a C
+// string, creating it on first use. The global is named
+// `<moduleName>_<kernelName>_kernel_name`, so repeated launches of the same
+// kernel share a single global.
+llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
+                                                         StringRef kernelName) {
+  std::string globalName =
+      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
+
+  // `CreateGlobalString` creates a private (local-linkage) global, so the
+  // lookup must allow local linkage; otherwise it never finds the existing
+  // string and every launch emits a fresh, renamed copy.
+  if (GlobalVariable *gv =
+          module.getGlobalVariable(globalName, /*AllowInternal=*/true))
+    return gv;
+
+  return builder.CreateGlobalString(kernelName, globalName);
+}
+
+// Packs the translated kernel operands of `op` for the runtime launch call:
+// every operand is stored into a field of a stack-allocated struct, and the
+// returned stack array holds one type-erased pointer per struct field.
+// Emitted IR, in order:
+//
+//   %struct = alloca { T0, T1, ... }
+//   %array  = alloca ptr, N
+//   for each operand i:
+//     %field = getelementptr %struct, 0, i ; store operand i, %field
+//     %slot  = getelementptr %array, i     ; store %field, %slot
+llvm::Value *
+llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
+  SmallVector<Value *> args =
+      moduleTranslation.lookupValues(op.getKernelOperands());
+  SmallVector<Type *> fieldTypes;
+  fieldTypes.reserve(args.size());
+  for (Value *arg : args)
+    fieldTypes.push_back(arg->getType());
+
+  Type *structTy = StructType::create(module.getContext(), fieldTypes);
+  Value *argStruct = builder.CreateAlloca(structTy, 0u);
+  Value *numArgs = ConstantInt::get(intPtrTy, fieldTypes.size());
+  Value *argArray = builder.CreateAlloca(ptrTy, numArgs);
+
+  for (unsigned idx = 0, e = args.size(); idx < e; ++idx) {
+    Value *field = builder.CreateStructGEP(structTy, argStruct, idx);
+    builder.CreateStore(args[idx], field);
+    Value *slot = builder.CreateConstGEP1_32(ptrTy, argArray, idx);
+    builder.CreateStore(field, slot);
+  }
+  return argArray;
+}
+
+// Emits LLVM IR to launch the kernel referenced by `op`. The emitted sequence
+// is:
+// %args = <see createKernelArgArray>
+// %module = call @mgpuModuleLoad(<binary global created by embedBinary>)
+// %name = <see getOrCreateFunctionName>
+// %function = call @mgpuModuleGetFunction(%module, %name)
+// %stream = <the async object if present, else call @mgpuStreamCreate()>
+// call @mgpuLaunchKernel(%function, <grid x, y, z>, <block x, y, z>,
+//                        <dynamic shared memory size, or i32 0>, %stream,
+//                        %args, null)
+// <only if the stream was created above>:
+//   call @mgpuStreamSynchronize(%stream)
+//   call @mgpuStreamDestroy(%stream)
+// call @mgpuModuleUnload(%module)
+//
+// NOTE(review): the module is unloaded right after the launch call even when
+// an async object was provided (asynchronous launch); confirm the runtime
+// keeps the module alive until the kernel completes.
+mlir::LogicalResult
+llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op) {
+  auto llvmValue = [&](mlir::Value value) -> Value * {
+    Value *v = moduleTranslation.lookupValue(value);
+    assert(v && "Value has not been translated.");
+    return v;
+  };
+
+  // Look up the binary embedded by `embedBinary` first. The lookup emits no
+  // IR, so a missing binary is reported before any launch IR is generated.
+  StringRef moduleName = op.getKernelModuleName().getValue();
+  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
+  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
+  if (!binary)
+    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
+
+  // Get grid dimensions.
+  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
+  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
+        *gz = llvmValue(grid.z);
+
+  // Get block dimensions.
+  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
+  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
+        *bz = llvmValue(block.z);
+
+  // Get the dynamic shared memory size; it defaults to i32 0 when absent.
+  Value *dynamicMemorySize = nullptr;
+  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
+    dynamicMemorySize = llvmValue(dynSz);
+  else
+    dynamicMemorySize = ConstantInt::get(i32Ty, 0);
+
+  // Create the argument array.
+  Value *argArray = createKernelArgArray(op);
+
+  // Load the kernel module.
+  Value *moduleObject = builder.CreateCall(getModuleLoadFn(), {binary});
+
+  // Load the kernel function.
+  Value *moduleFunction = builder.CreateCall(
+      getModuleFunctionFn(),
+      {moduleObject,
+       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
+
+  // Get the stream to use for execution. If there's no async object then create
+  // a stream to make a synchronous kernel launch.
+  Value *stream = nullptr;
+  bool handleStream = false;
+  if (mlir::Value asyncObject = op.getAsyncObject()) {
+    stream = llvmValue(asyncObject);
+  } else {
+    handleStream = true;
+    stream = builder.CreateCall(getStreamCreateFn(), {});
+  }
+
+  // Create the launch call; the trailing argument is always a null pointer.
+  Value *nullPtr = ConstantPointerNull::get(ptrTy);
+  builder.CreateCall(
+      getKernelLaunchFn(),
+      ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
+                         dynamicMemorySize, stream, argArray, nullPtr}));
+
+  // Sync & destroy the stream, for synchronous launches.
+  if (handleStream) {
+    builder.CreateCall(getStreamSyncFn(), {stream});
+    builder.CreateCall(getStreamDestroyFn(), {stream});
+  }
+
+  // Unload the kernel module.
+  builder.CreateCall(getModuleUnloadFn(), {moduleObject});
+
+  return success();
+}

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 0ef7cfb854e3e4..b314a768a08963 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -121,6 +121,12 @@ module attributes {gpu.container_module} {
     }
   }
 
+  gpu.binary @binary_1 [#gpu.object<#nvvm.target, "">]
+
+  gpu.binary @binary_2 <#gpu.select_object<#nvvm.target<chip = "sm_90">>> [#gpu.object<#nvvm.target, "">, #gpu.object<#nvvm.target<chip = "sm_90">, "">]
+
+  gpu.binary @binary_3 <#gpu.select_object<1>> [#gpu.object<#nvvm.target, "">, #gpu.object<#nvvm.target<chip = "sm_90">, "">]
+
   func.func private @two_value_generator() -> (f32, memref<?xf32, 1>)
 
   func.func @foo() {
@@ -150,6 +156,9 @@ module attributes {gpu.container_module} {
     // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i32 args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
     gpu.launch_func @kernels::@kernel_1 blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0) : i32 args(%0 : f32, %1 : memref<?xf32, 1>)
 
+    // CHECK: gpu.launch_func @binary_1::@kernel blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i32 args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+    gpu.launch_func @binary_1::@kernel blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0) : i32 args(%0 : f32, %1 : memref<?xf32, 1>)
+
     // CHECK: %[[VALUES:.*]]:2 = call
     %values:2 = func.call @two_value_generator() : () -> (f32, memref<?xf32, 1>)
     // CHECK: gpu.launch_func @kernels::@kernel_1 {{.*}} args(%[[VALUES]]#0 : f32, %[[VALUES]]#1 : memref<?xf32, 1>)


        


More information about the Mlir-commits mailing list