[Mlir-commits] [mlir] [mlir][gpu] Update LaunchFuncOp lowering in GPU to LLVM (PR #94991)

Fabian Mora llvmlistbot at llvm.org
Mon Jun 10 08:01:43 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/94991

>From 04ea85f2cb8365c1aed8412b8e5ec0daad8ba149 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 10 Jun 2024 14:33:48 +0000
Subject: [PATCH] [mlir][gpu] Update LaunchFuncOp lowering in GPU to LLVM

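This patch updates the lowering of `LaunchFuncOp` in GPU to LLVM so that it
only legalizes the operation (converting its operands to LLVM types) instead of
emitting GPU runtime calls. It also removes all remaining uses of the old
compilation infrastructure, including the `gpu-binary-annotation` pass option
and the corresponding parameters of `populateGpuToLLVMConversionPatterns`.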
---
 .../mlir/Conversion/GPUCommon/GPUCommonPass.h |  10 +-
 mlir/include/mlir/Conversion/Passes.td        |   6 +-
 .../GPUCommon/GPUToLLVMConversion.cpp         | 345 +++---------------
 ...ower-launch-func-to-gpu-runtime-calls.mlir |  73 +---
 4 files changed, 66 insertions(+), 368 deletions(-)

diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 48b7835ae5fca..2d5e9d27c5bdf 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -46,9 +46,6 @@ class LLVMDialect;
 #define GEN_PASS_DECL_GPUTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
 
-using OwnedBlob = std::unique_ptr<std::vector<char>>;
-using BlobGenerator =
-    std::function<OwnedBlob(const std::string &, Location, StringRef)>;
 using LoweringCallback = std::function<std::unique_ptr<llvm::Module>(
     Operation *, llvm::LLVMContext &, StringRef)>;
 
@@ -66,10 +63,9 @@ struct FunctionCallBuilder {
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
-void populateGpuToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    StringRef gpuBinaryAnnotation = {}, bool kernelBarePtrCallConv = false,
-    SymbolTable *cachedModuleTable = nullptr);
+void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns,
+                                         bool kernelBarePtrCallConv = false);
 
 /// A function that maps a MemorySpace enum to a target-specific integer value.
 using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eb58f4adc31d3..db67d6a5ff128 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -478,11 +478,7 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
            /*default=*/"false",
              "Use bare pointers to pass memref arguments to kernels. "
              "The kernel must use the same setting for this option."
-           >,
-    Option<"gpuBinaryAnnotation", "gpu-binary-annotation", "std::string",
-               /*default=*/"gpu::getDefaultGpuBinaryAnnotation()",
-               "Annotation attribute string for GPU binary"
-               >
+           >
   ];
 
   let dependentDialects = [
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 82bfa9514a884..92b28ff9c5873 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -49,8 +49,6 @@ namespace mlir {
 
 using namespace mlir;
 
-static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
-
 namespace {
 class GpuToLLVMConversionPass
     : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
@@ -97,36 +95,6 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Type llvmIntPtrType = IntegerType::get(
       context, this->getTypeConverter()->getPointerBitwidth(0));
 
-  FunctionCallBuilder moduleLoadCallBuilder = {
-      "mgpuModuleLoad",
-      llvmPointerType /* void *module */,
-      {llvmPointerType /* void *cubin */, llvmInt64Type /* size_t size */}};
-  FunctionCallBuilder moduleUnloadCallBuilder = {
-      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
-  FunctionCallBuilder moduleGetFunctionCallBuilder = {
-      "mgpuModuleGetFunction",
-      llvmPointerType /* void *function */,
-      {
-          llvmPointerType, /* void *module */
-          llvmPointerType  /* char *name   */
-      }};
-  FunctionCallBuilder launchKernelCallBuilder = {
-      "mgpuLaunchKernel",
-      llvmVoidType,
-      {
-          llvmPointerType, /* void* f */
-          llvmIntPtrType,  /* intptr_t gridXDim */
-          llvmIntPtrType,  /* intptr_t gridyDim */
-          llvmIntPtrType,  /* intptr_t gridZDim */
-          llvmIntPtrType,  /* intptr_t blockXDim */
-          llvmIntPtrType,  /* intptr_t blockYDim */
-          llvmIntPtrType,  /* intptr_t blockZDim */
-          llvmInt32Type,   /* unsigned int sharedMemBytes */
-          llvmPointerType, /* void *hstream */
-          llvmPointerType, /* void **kernelParams */
-          llvmPointerType, /* void **extra */
-          llvmInt64Type    /* size_t paramsCount */
-      }};
   FunctionCallBuilder streamCreateCallBuilder = {
       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
   FunctionCallBuilder streamDestroyCallBuilder = {
@@ -451,55 +419,21 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// A rewrite patter to convert gpu.launch_func operations into a sequence of
-/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
-///
-/// In essence, a gpu.launch_func operations gets compiled into the following
-/// sequence of runtime calls:
-///
-/// * moduleLoad        -- loads the module given the cubin / hsaco data
-/// * moduleGetFunction -- gets a handle to the actual kernel function
-/// * getStreamHelper   -- initializes a new compute stream on GPU
-/// * launchKernel      -- launches the kernel on a stream
-/// * streamSynchronize -- waits for operations on the stream to finish
-///
-/// Intermediate data structures are allocated on the stack.
-class ConvertLaunchFuncOpToGpuRuntimeCallPattern
+/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
+class LegalizeLaunchFuncOpPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
 public:
-  ConvertLaunchFuncOpToGpuRuntimeCallPattern(
-      const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation,
-      bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable)
+  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
+                              bool kernelBarePtrCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
-        gpuBinaryAnnotation(gpuBinaryAnnotation),
-        kernelBarePtrCallConv(kernelBarePtrCallConv),
-        cachedModuleTable(cachedModuleTable) {}
+        kernelBarePtrCallConv(kernelBarePtrCallConv) {}
 
 private:
-  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
-                            OpBuilder &builder) const;
-  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
-                                   Location loc, OpBuilder &builder) const;
-
   LogicalResult
   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 
-  llvm::SmallString<32> gpuBinaryAnnotation;
   bool kernelBarePtrCallConv;
-  SymbolTable *cachedModuleTable;
-};
-
-class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
-  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
-                                PatternRewriter &rewriter) const override {
-    // GPU kernel modules are no longer necessary since we have a global
-    // constant with the CUBIN, or HSACO data.
-    rewriter.eraseOp(op);
-    return success();
-  }
 };
 
 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -587,7 +521,6 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  SymbolTable symbolTable = SymbolTable(getOperation());
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);
@@ -604,30 +537,20 @@ void GpuToLLVMConversionPass::runOnOperation() {
     iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
   }
 
-  // Preserve GPU modules if they have target attributes.
-  target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
-      [](gpu::GPUModuleOp module) -> bool {
-        return module.getTargetsAttr() != nullptr;
-      });
-  // Accept as legal LaunchFuncOps if they refer to GPU Modules with targets and
-  // the operands have been lowered.
+  // Preserve GPU modules and binaries. Modules are preserved as they can be
+  // converted later by `gpu-module-to-binary`.
+  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
+  // Accept as legal LaunchFuncOps if the operands have been lowered.
   target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
-      [&](gpu::LaunchFuncOp op) -> bool {
-        auto module =
-            symbolTable.lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
-        return converter.isLegal(op->getOperandTypes()) &&
-               converter.isLegal(op->getResultTypes()) &&
-               (module && module.getTargetsAttr() &&
-                !module.getTargetsAttr().empty());
-      });
+      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
 
   // These aren't covered by the ConvertToLLVMPatternInterface right now.
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
-                                      kernelBarePtrCallConv, &symbolTable);
+  populateGpuToLLVMConversionPatterns(converter, patterns,
+                                      kernelBarePtrCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -1002,100 +925,8 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-// Creates a struct containing all kernel parameters on the stack and returns
-// an array of type-erased pointers to the fields of the struct. The array can
-// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
-// The generated code is essentially as follows:
-//
-// %struct = alloca(sizeof(struct { Parameters... }))
-// %array = alloca(NumParameters * sizeof(void *))
-// for (i : [0, NumParameters))
-//   %fieldPtr = llvm.getelementptr %struct[0, i]
-//   llvm.store parameters[i], %fieldPtr
-//   %elementPtr = llvm.getelementptr %array[i]
-//   llvm.store %fieldPtr, %elementPtr
-// return %array
-Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
-    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
-  auto loc = launchOp.getLoc();
-  auto numKernelOperands = launchOp.getNumKernelOperands();
-  // Note: If `useBarePtrCallConv` is set in the type converter's options,
-  // the value of `kernelBarePtrCallConv` will be ignored.
-  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getOperands().take_back(numKernelOperands),
-      adaptor.getOperands().take_back(numKernelOperands), builder,
-      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
-  auto numArguments = arguments.size();
-  SmallVector<Type, 4> argumentTypes;
-  argumentTypes.reserve(numArguments);
-  for (auto argument : arguments)
-    argumentTypes.push_back(argument.getType());
-  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
-                                                           argumentTypes);
-  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
-  auto structPtr =
-      builder.create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one,
-                                     /*alignment=*/0);
-  auto arraySize =
-      builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
-  auto arrayPtr = builder.create<LLVM::AllocaOp>(
-      loc, llvmPointerType, llvmPointerType, arraySize, /*alignment=*/0);
-  for (const auto &en : llvm::enumerate(arguments)) {
-    const auto index = static_cast<int32_t>(en.index());
-    Value fieldPtr =
-        builder.create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr,
-                                    ArrayRef<LLVM::GEPArg>{0, index});
-    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
-    auto elementPtr =
-        builder.create<LLVM::GEPOp>(loc, llvmPointerType, llvmPointerType,
-                                    arrayPtr, ArrayRef<LLVM::GEPArg>{index});
-    builder.create<LLVM::StoreOp>(loc, fieldPtr, elementPtr);
-  }
-  return arrayPtr;
-}
-
-// Generates an LLVM IR dialect global that contains the name of the given
-// kernel function as a C string, and returns a pointer to its beginning.
-// The code is essentially:
-//
-// llvm.global constant @kernel_name("function_name\00")
-// func(...) {
-//   %0 = llvm.addressof @kernel_name
-//   %1 = llvm.constant (0 : index)
-//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
-// }
-Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
-    StringRef moduleName, StringRef name, Location loc,
-    OpBuilder &builder) const {
-  // Make sure the trailing zero is included in the constant.
-  std::vector<char> kernelName(name.begin(), name.end());
-  kernelName.push_back('\0');
-
-  std::string globalName =
-      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
-  return LLVM::createGlobalString(
-      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
-      LLVM::Linkage::Internal);
-}
-
-// Emits LLVM IR to launch a kernel function. Expects the module that contains
-// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
-// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
-//
-// %0 = call %binarygetter
-// %1 = call %moduleLoad(%0)
-// %2 = <see generateKernelNameConstant>
-// %3 = call %moduleGetFunction(%1, %2)
-// %4 = call %streamCreate()
-// %5 = <see generateParamsArray>
-// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
-// call %streamSynchronize(%4)
-// call %streamDestroy(%4)
-// call %moduleUnload(%1)
-//
-// If the op is async, the stream corresponds to the (single) async dependency
-// as well as the async token the op produces.
-LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
+// Legalize the op's operands.
+LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
     gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
@@ -1114,123 +945,37 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
 
   Location loc = launchOp.getLoc();
 
-  // Create an LLVM global with CUBIN extracted from the kernel annotation and
-  // obtain a pointer to the first byte in it.
-  gpu::GPUModuleOp kernelModule;
-  if (cachedModuleTable)
-    kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
-        launchOp.getKernelModuleName());
-  else
-    kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
-        launchOp, launchOp.getKernelModuleName());
-  assert(kernelModule && "expected a kernel module");
-
-  // If the module has Targets then just update the op operands.
-  if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
-    Value stream = Value();
-    if (!adaptor.getAsyncDependencies().empty())
-      stream = adaptor.getAsyncDependencies().front();
-    // If the async keyword is present and there are no dependencies, then a
-    // stream must be created to pass to subsequent operations.
-    else if (launchOp.getAsyncToken())
-      stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
-
-    // Lower the kernel operands to match kernel parameters.
-    // Note: If `useBarePtrCallConv` is set in the type converter's options,
-    // the value of `kernelBarePtrCallConv` will be ignored.
-    SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-        loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
-        rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);
-
-    std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
-    if (launchOp.hasClusterSize()) {
-      clusterSize =
-          gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
-                          adaptor.getClusterSizeZ()};
-    }
-    rewriter.create<gpu::LaunchFuncOp>(
-        launchOp.getLoc(), launchOp.getKernelAttr(),
-        gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
-                        adaptor.getGridSizeZ()},
-        gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
-                        adaptor.getBlockSizeZ()},
-        adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
-    if (launchOp.getAsyncToken())
-      rewriter.replaceOp(launchOp, {stream});
-    else
-      rewriter.eraseOp(launchOp);
-    return success();
-  }
+  Value stream = Value();
+  if (!adaptor.getAsyncDependencies().empty())
+    stream = adaptor.getAsyncDependencies().front();
+  // If the async keyword is present and there are no dependencies, then a
+  // stream must be created to pass to subsequent operations.
+  else if (launchOp.getAsyncToken())
+    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
+      loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
-  auto binaryAttr =
-      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
-  if (!binaryAttr) {
-    kernelModule.emitOpError()
-        << "missing " << gpuBinaryAnnotation << " attribute";
-    return failure();
+  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
+  if (launchOp.hasClusterSize()) {
+    clusterSize =
+        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
+                        adaptor.getClusterSizeZ()};
   }
-
-  SmallString<128> nameBuffer(kernelModule.getName());
-  nameBuffer.append(kGpuBinaryStorageSuffix);
-  Value data =
-      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
-                               binaryAttr.getValue(), LLVM::Linkage::Internal);
-
-  // Pass the binary size. SPIRV requires binary size.
-  auto gpuBlob = binaryAttr.getValue();
-  auto gpuBlobSize = rewriter.create<mlir::LLVM::ConstantOp>(
-      loc, llvmInt64Type,
-      mlir::IntegerAttr::get(llvmInt64Type,
-                             static_cast<int64_t>(gpuBlob.size())));
-
-  auto module =
-      moduleLoadCallBuilder.create(loc, rewriter, {data, gpuBlobSize});
-
-  // Pass the count of the parameters to runtime wrappers
-  auto paramsCount = rewriter.create<mlir::LLVM::ConstantOp>(
-      loc, llvmInt64Type,
-      mlir::IntegerAttr::get(
-          llvmInt64Type,
-          static_cast<int64_t>(launchOp.getNumKernelOperands())));
-
-  // Get the function from the module. The name corresponds to the name of
-  // the kernel function.
-  auto kernelName = generateKernelNameConstant(
-      launchOp.getKernelModuleName().getValue(),
-      launchOp.getKernelName().getValue(), loc, rewriter);
-  auto function = moduleGetFunctionCallBuilder.create(
-      loc, rewriter, {module.getResult(), kernelName});
-  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
-  Value stream =
-      adaptor.getAsyncDependencies().empty()
-          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
-          : adaptor.getAsyncDependencies().front();
-  // Create array of pointers to kernel arguments.
-  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
-  auto nullpointer = rewriter.create<LLVM::ZeroOp>(loc, llvmPointerType);
-  Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize()
-                                      ? launchOp.getDynamicSharedMemorySize()
-                                      : zero;
-  launchKernelCallBuilder.create(
-      loc, rewriter,
-      {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
-       adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
-       adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
-       /*extra=*/nullpointer, paramsCount});
-
-  if (launchOp.getAsyncToken()) {
-    // Async launch: make dependent ops use the same stream.
+  rewriter.create<gpu::LaunchFuncOp>(
+      launchOp.getLoc(), launchOp.getKernelAttr(),
+      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
+                      adaptor.getGridSizeZ()},
+      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
+                      adaptor.getBlockSizeZ()},
+      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+  if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
-  } else {
-    // Synchronize with host and destroy stream. This must be the stream created
-    // above (with no other uses) because we check that the synchronous version
-    // does not have any async dependencies.
-    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
-    streamDestroyCallBuilder.create(loc, rewriter, stream);
+  else
     rewriter.eraseOp(launchOp);
-  }
-  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
-
   return success();
 }
 
@@ -1978,9 +1723,7 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
 
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
-                                               StringRef gpuBinaryAnnotation,
-                                               bool kernelBarePtrCallConv,
-                                               SymbolTable *cachedModuleTable) {
+                                               bool kernelBarePtrCallConv) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -2017,7 +1760,5 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
-  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
-      converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
-  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
+  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
 }
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index c0b05ef086033..6c5c1e09c0eb5 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -1,15 +1,8 @@
-// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=nvvm.cubin" -split-input-file | FileCheck %s
-// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=rocdl.hsaco" -split-input-file | FileCheck %s --check-prefix=ROCDL
+// RUN: mlir-opt %s --gpu-to-llvm -split-input-file | FileCheck %s
 
 module attributes {gpu.container_module} {
-
-  // CHECK: llvm.mlir.global internal constant @[[KERNEL_NAME:.*]]("kernel\00")
-  // CHECK: llvm.mlir.global internal constant @[[GLOBAL:.*]]("CUBIN")
-  // ROCDL: llvm.mlir.global internal constant @[[GLOBAL:.*]]("HSACO")
-
-  gpu.module @kernel_module attributes {
-      nvvm.cubin = "CUBIN", rocdl.hsaco = "HSACO"
-  } {
+  // CHECK: gpu.module
+  gpu.module @kernel_module [#nvvm.target] {
     llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
         %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
         %arg5: i64) attributes {gpu.kernel} {
@@ -18,9 +11,17 @@ module attributes {gpu.container_module} {
   }
 
   func.func @foo(%buffer: memref<?xf32>) {
+  // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+  // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
     %c8 = arith.constant 8 : index
     %c32 = arith.constant 32 : i32
     %c256 = arith.constant 256 : i32
+
+  // CHECK: gpu.launch_func @kernel_module::@kernel
+  // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
+  // CHECK: dynamic_shared_memory_size [[C256]]
+  // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
     gpu.launch_func @kernel_module::@kernel
         blocks in (%c8, %c8, %c8)
         threads in (%c8, %c8, %c8)
@@ -28,46 +29,13 @@ module attributes {gpu.container_module} {
         args(%c32 : i32, %buffer : memref<?xf32>)
     return
   }
-
-  // CHECK-DAG: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
-  // CHECK-DAG: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
-  // CHECK: [[ADDRESSOF:%.*]] = llvm.mlir.addressof @[[GLOBAL]]
-  // CHECK: [[BINARY:%.*]] = llvm.getelementptr [[ADDRESSOF]]{{\[}}0, 0]
-  // CHECK-SAME: -> !llvm.ptr
-  // CHECK: [[BINARYSIZE:%.*]] = llvm.mlir.constant
-  // CHECK: [[MODULE:%.*]] = llvm.call @mgpuModuleLoad([[BINARY]], [[BINARYSIZE]])
-  // CHECK: [[PARAMSCOUNT:%.*]] = llvm.mlir.constant
-  // CHECK: [[FUNC:%.*]] = llvm.call @mgpuModuleGetFunction([[MODULE]], {{.*}})
-
-  // CHECK: [[STREAM:%.*]] = llvm.call @mgpuStreamCreate
-
-  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32)
-  // CHECK: %[[MEMREF:.*]] = llvm.alloca %[[ONE]] x !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: [[NUM_PARAMS:%.*]] = llvm.mlir.constant(6 : i32) : i32
-  // CHECK-NEXT: [[PARAMS:%.*]] = llvm.alloca [[NUM_PARAMS]] x !llvm.ptr
-
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-
-  // CHECK: [[EXTRA_PARAMS:%.*]] = llvm.mlir.zero : !llvm.ptr
-
-  // CHECK: llvm.call @mgpuLaunchKernel([[FUNC]], [[C8]], [[C8]], [[C8]],
-  // CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C256]], [[STREAM]],
-  // CHECK-SAME: [[PARAMS]], [[EXTRA_PARAMS]], [[PARAMSCOUNT]])
-  // CHECK: llvm.call @mgpuStreamSynchronize
-  // CHECK: llvm.call @mgpuStreamDestroy
-  // CHECK: llvm.call @mgpuModuleUnload
 }
 
+
 // -----
 
 module attributes {gpu.container_module} {
   // CHECK: gpu.module
-  // ROCDL: gpu.module
   gpu.module @kernel_module [#nvvm.target] {
     llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
         %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
@@ -80,15 +48,19 @@ module attributes {gpu.container_module} {
   // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
   // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
   // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
-    %c8 = arith.constant 8 : index
+  // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+    %c8 = arith.constant 8 : index
     %c32 = arith.constant 32 : i32
     %c256 = arith.constant 256 : i32
+    %c2 = arith.constant 2 : index
 
   // CHECK: gpu.launch_func @kernel_module::@kernel
+  // CHECK: clusters in ([[C2]], [[C2]], [[C2]])
   // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
   // CHECK: dynamic_shared_memory_size [[C256]]
   // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
     gpu.launch_func @kernel_module::@kernel
+        clusters in (%c2, %c2, %c2)
         blocks in (%c8, %c8, %c8)
         threads in (%c8, %c8, %c8)
         dynamic_shared_memory_size %c256
@@ -97,18 +69,11 @@ module attributes {gpu.container_module} {
   }
 }
 
-
 // -----
 
 module attributes {gpu.container_module} {
-  // CHECK: gpu.module
-  gpu.module @kernel_module [#nvvm.target] {
-    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
-        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
-        %arg5: i64) attributes {gpu.kernel} {
-      llvm.return
-    }
-  }
+  // CHECK: gpu.binary
+  gpu.binary @kernel_module [#gpu.object<#rocdl.target, "blob">]
 
   func.func @foo(%buffer: memref<?xf32>) {
   // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64



More information about the Mlir-commits mailing list