[Mlir-commits] [mlir] 8e12f31 - [mlir][gpu] Update LaunchFuncOp lowering in GPU to LLVM (#94991)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 10 18:22:26 PDT 2024


Author: Fabian Mora
Date: 2024-06-10T20:22:22-05:00
New Revision: 8e12f31be5a98a66700dd3571e4e12465f05ad61

URL: https://github.com/llvm/llvm-project/commit/8e12f31be5a98a66700dd3571e4e12465f05ad61
DIFF: https://github.com/llvm/llvm-project/commit/8e12f31be5a98a66700dd3571e4e12465f05ad61.diff

LOG: [mlir][gpu] Update LaunchFuncOp lowering in GPU to LLVM (#94991)

This patch updates the lowering of `LaunchFuncOp` in GPU to LLVM so that it
only legalizes the operation with the converted operands, effectively removing
the lowering used by the old serialization pipeline.
It also removes all remaining uses of the old GPU serialization
infrastructure in `gpu-to-llvm`.

See [Compilation overview | 'gpu' Dialect - MLIR
docs](https://mlir.llvm.org/docs/Dialects/GPU/#compilation-overview) for
additional information on the target attributes compilation pipeline
that replaced the old serialization pipeline.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 48b7835ae5fca..2d5e9d27c5bdf 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -46,9 +46,6 @@ class LLVMDialect;
 #define GEN_PASS_DECL_GPUTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
 
-using OwnedBlob = std::unique_ptr<std::vector<char>>;
-using BlobGenerator =
-    std::function<OwnedBlob(const std::string &, Location, StringRef)>;
 using LoweringCallback = std::function<std::unique_ptr<llvm::Module>(
     Operation *, llvm::LLVMContext &, StringRef)>;
 
@@ -66,10 +63,9 @@ struct FunctionCallBuilder {
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
-void populateGpuToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    StringRef gpuBinaryAnnotation = {}, bool kernelBarePtrCallConv = false,
-    SymbolTable *cachedModuleTable = nullptr);
+void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns,
+                                         bool kernelBarePtrCallConv = false);
 
 /// A function that maps a MemorySpace enum to a target-specific integer value.
 using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eb58f4adc31d3..db67d6a5ff128 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -478,11 +478,7 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
            /*default=*/"false",
              "Use bare pointers to pass memref arguments to kernels. "
              "The kernel must use the same setting for this option."
-           >,
-    Option<"gpuBinaryAnnotation", "gpu-binary-annotation", "std::string",
-               /*default=*/"gpu::getDefaultGpuBinaryAnnotation()",
-               "Annotation attribute string for GPU binary"
-               >
+           >
   ];
 
   let dependentDialects = [

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 82bfa9514a884..92b28ff9c5873 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -49,8 +49,6 @@ namespace mlir {
 
 using namespace mlir;
 
-static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
-
 namespace {
 class GpuToLLVMConversionPass
     : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
@@ -97,36 +95,6 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Type llvmIntPtrType = IntegerType::get(
       context, this->getTypeConverter()->getPointerBitwidth(0));
 
-  FunctionCallBuilder moduleLoadCallBuilder = {
-      "mgpuModuleLoad",
-      llvmPointerType /* void *module */,
-      {llvmPointerType /* void *cubin */, llvmInt64Type /* size_t size */}};
-  FunctionCallBuilder moduleUnloadCallBuilder = {
-      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
-  FunctionCallBuilder moduleGetFunctionCallBuilder = {
-      "mgpuModuleGetFunction",
-      llvmPointerType /* void *function */,
-      {
-          llvmPointerType, /* void *module */
-          llvmPointerType  /* char *name   */
-      }};
-  FunctionCallBuilder launchKernelCallBuilder = {
-      "mgpuLaunchKernel",
-      llvmVoidType,
-      {
-          llvmPointerType, /* void* f */
-          llvmIntPtrType,  /* intptr_t gridXDim */
-          llvmIntPtrType,  /* intptr_t gridyDim */
-          llvmIntPtrType,  /* intptr_t gridZDim */
-          llvmIntPtrType,  /* intptr_t blockXDim */
-          llvmIntPtrType,  /* intptr_t blockYDim */
-          llvmIntPtrType,  /* intptr_t blockZDim */
-          llvmInt32Type,   /* unsigned int sharedMemBytes */
-          llvmPointerType, /* void *hstream */
-          llvmPointerType, /* void **kernelParams */
-          llvmPointerType, /* void **extra */
-          llvmInt64Type    /* size_t paramsCount */
-      }};
   FunctionCallBuilder streamCreateCallBuilder = {
       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
   FunctionCallBuilder streamDestroyCallBuilder = {
@@ -451,55 +419,21 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// A rewrite patter to convert gpu.launch_func operations into a sequence of
-/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
-///
-/// In essence, a gpu.launch_func operations gets compiled into the following
-/// sequence of runtime calls:
-///
-/// * moduleLoad        -- loads the module given the cubin / hsaco data
-/// * moduleGetFunction -- gets a handle to the actual kernel function
-/// * getStreamHelper   -- initializes a new compute stream on GPU
-/// * launchKernel      -- launches the kernel on a stream
-/// * streamSynchronize -- waits for operations on the stream to finish
-///
-/// Intermediate data structures are allocated on the stack.
-class ConvertLaunchFuncOpToGpuRuntimeCallPattern
+/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
+class LegalizeLaunchFuncOpPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
 public:
-  ConvertLaunchFuncOpToGpuRuntimeCallPattern(
-      const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation,
-      bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable)
+  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
+                              bool kernelBarePtrCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
-        gpuBinaryAnnotation(gpuBinaryAnnotation),
-        kernelBarePtrCallConv(kernelBarePtrCallConv),
-        cachedModuleTable(cachedModuleTable) {}
+        kernelBarePtrCallConv(kernelBarePtrCallConv) {}
 
 private:
-  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
-                            OpBuilder &builder) const;
-  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
-                                   Location loc, OpBuilder &builder) const;
-
   LogicalResult
   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 
-  llvm::SmallString<32> gpuBinaryAnnotation;
   bool kernelBarePtrCallConv;
-  SymbolTable *cachedModuleTable;
-};
-
-class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
-  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
-                                PatternRewriter &rewriter) const override {
-    // GPU kernel modules are no longer necessary since we have a global
-    // constant with the CUBIN, or HSACO data.
-    rewriter.eraseOp(op);
-    return success();
-  }
 };
 
 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -587,7 +521,6 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  SymbolTable symbolTable = SymbolTable(getOperation());
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);
@@ -604,30 +537,20 @@ void GpuToLLVMConversionPass::runOnOperation() {
     iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
   }
 
-  // Preserve GPU modules if they have target attributes.
-  target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
-      [](gpu::GPUModuleOp module) -> bool {
-        return module.getTargetsAttr() != nullptr;
-      });
-  // Accept as legal LaunchFuncOps if they refer to GPU Modules with targets and
-  // the operands have been lowered.
+  // Preserve GPU modules and binaries. Modules are preserved as they can be
+  // converted later by `gpu-module-to-binary`.
+  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
+  // Accept as legal LaunchFuncOps if the operands have been lowered.
   target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
-      [&](gpu::LaunchFuncOp op) -> bool {
-        auto module =
-            symbolTable.lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
-        return converter.isLegal(op->getOperandTypes()) &&
-               converter.isLegal(op->getResultTypes()) &&
-               (module && module.getTargetsAttr() &&
-                !module.getTargetsAttr().empty());
-      });
+      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
 
   // These aren't covered by the ConvertToLLVMPatternInterface right now.
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
-                                      kernelBarePtrCallConv, &symbolTable);
+  populateGpuToLLVMConversionPatterns(converter, patterns,
+                                      kernelBarePtrCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -1002,100 +925,8 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-// Creates a struct containing all kernel parameters on the stack and returns
-// an array of type-erased pointers to the fields of the struct. The array can
-// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
-// The generated code is essentially as follows:
-//
-// %struct = alloca(sizeof(struct { Parameters... }))
-// %array = alloca(NumParameters * sizeof(void *))
-// for (i : [0, NumParameters))
-//   %fieldPtr = llvm.getelementptr %struct[0, i]
-//   llvm.store parameters[i], %fieldPtr
-//   %elementPtr = llvm.getelementptr %array[i]
-//   llvm.store %fieldPtr, %elementPtr
-// return %array
-Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
-    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
-  auto loc = launchOp.getLoc();
-  auto numKernelOperands = launchOp.getNumKernelOperands();
-  // Note: If `useBarePtrCallConv` is set in the type converter's options,
-  // the value of `kernelBarePtrCallConv` will be ignored.
-  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getOperands().take_back(numKernelOperands),
-      adaptor.getOperands().take_back(numKernelOperands), builder,
-      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
-  auto numArguments = arguments.size();
-  SmallVector<Type, 4> argumentTypes;
-  argumentTypes.reserve(numArguments);
-  for (auto argument : arguments)
-    argumentTypes.push_back(argument.getType());
-  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
-                                                           argumentTypes);
-  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
-  auto structPtr =
-      builder.create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one,
-                                     /*alignment=*/0);
-  auto arraySize =
-      builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
-  auto arrayPtr = builder.create<LLVM::AllocaOp>(
-      loc, llvmPointerType, llvmPointerType, arraySize, /*alignment=*/0);
-  for (const auto &en : llvm::enumerate(arguments)) {
-    const auto index = static_cast<int32_t>(en.index());
-    Value fieldPtr =
-        builder.create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr,
-                                    ArrayRef<LLVM::GEPArg>{0, index});
-    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
-    auto elementPtr =
-        builder.create<LLVM::GEPOp>(loc, llvmPointerType, llvmPointerType,
-                                    arrayPtr, ArrayRef<LLVM::GEPArg>{index});
-    builder.create<LLVM::StoreOp>(loc, fieldPtr, elementPtr);
-  }
-  return arrayPtr;
-}
-
-// Generates an LLVM IR dialect global that contains the name of the given
-// kernel function as a C string, and returns a pointer to its beginning.
-// The code is essentially:
-//
-// llvm.global constant @kernel_name("function_name\00")
-// func(...) {
-//   %0 = llvm.addressof @kernel_name
-//   %1 = llvm.constant (0 : index)
-//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
-// }
-Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
-    StringRef moduleName, StringRef name, Location loc,
-    OpBuilder &builder) const {
-  // Make sure the trailing zero is included in the constant.
-  std::vector<char> kernelName(name.begin(), name.end());
-  kernelName.push_back('\0');
-
-  std::string globalName =
-      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
-  return LLVM::createGlobalString(
-      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
-      LLVM::Linkage::Internal);
-}
-
-// Emits LLVM IR to launch a kernel function. Expects the module that contains
-// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
-// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
-//
-// %0 = call %binarygetter
-// %1 = call %moduleLoad(%0)
-// %2 = <see generateKernelNameConstant>
-// %3 = call %moduleGetFunction(%1, %2)
-// %4 = call %streamCreate()
-// %5 = <see generateParamsArray>
-// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
-// call %streamSynchronize(%4)
-// call %streamDestroy(%4)
-// call %moduleUnload(%1)
-//
-// If the op is async, the stream corresponds to the (single) async dependency
-// as well as the async token the op produces.
-LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
+// Legalize the op's operands.
+LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
     gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
@@ -1114,123 +945,37 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
 
   Location loc = launchOp.getLoc();
 
-  // Create an LLVM global with CUBIN extracted from the kernel annotation and
-  // obtain a pointer to the first byte in it.
-  gpu::GPUModuleOp kernelModule;
-  if (cachedModuleTable)
-    kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
-        launchOp.getKernelModuleName());
-  else
-    kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
-        launchOp, launchOp.getKernelModuleName());
-  assert(kernelModule && "expected a kernel module");
-
-  // If the module has Targets then just update the op operands.
-  if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
-    Value stream = Value();
-    if (!adaptor.getAsyncDependencies().empty())
-      stream = adaptor.getAsyncDependencies().front();
-    // If the async keyword is present and there are no dependencies, then a
-    // stream must be created to pass to subsequent operations.
-    else if (launchOp.getAsyncToken())
-      stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
-
-    // Lower the kernel operands to match kernel parameters.
-    // Note: If `useBarePtrCallConv` is set in the type converter's options,
-    // the value of `kernelBarePtrCallConv` will be ignored.
-    SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-        loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
-        rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);
-
-    std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
-    if (launchOp.hasClusterSize()) {
-      clusterSize =
-          gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
-                          adaptor.getClusterSizeZ()};
-    }
-    rewriter.create<gpu::LaunchFuncOp>(
-        launchOp.getLoc(), launchOp.getKernelAttr(),
-        gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
-                        adaptor.getGridSizeZ()},
-        gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
-                        adaptor.getBlockSizeZ()},
-        adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
-    if (launchOp.getAsyncToken())
-      rewriter.replaceOp(launchOp, {stream});
-    else
-      rewriter.eraseOp(launchOp);
-    return success();
-  }
+  Value stream = Value();
+  if (!adaptor.getAsyncDependencies().empty())
+    stream = adaptor.getAsyncDependencies().front();
+  // If the async keyword is present and there are no dependencies, then a
+  // stream must be created to pass to subsequent operations.
+  else if (launchOp.getAsyncToken())
+    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
+      loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
-  auto binaryAttr =
-      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
-  if (!binaryAttr) {
-    kernelModule.emitOpError()
-        << "missing " << gpuBinaryAnnotation << " attribute";
-    return failure();
+  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
+  if (launchOp.hasClusterSize()) {
+    clusterSize =
+        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
+                        adaptor.getClusterSizeZ()};
   }
-
-  SmallString<128> nameBuffer(kernelModule.getName());
-  nameBuffer.append(kGpuBinaryStorageSuffix);
-  Value data =
-      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
-                               binaryAttr.getValue(), LLVM::Linkage::Internal);
-
-  // Pass the binary size. SPIRV requires binary size.
-  auto gpuBlob = binaryAttr.getValue();
-  auto gpuBlobSize = rewriter.create<mlir::LLVM::ConstantOp>(
-      loc, llvmInt64Type,
-      mlir::IntegerAttr::get(llvmInt64Type,
-                             static_cast<int64_t>(gpuBlob.size())));
-
-  auto module =
-      moduleLoadCallBuilder.create(loc, rewriter, {data, gpuBlobSize});
-
-  // Pass the count of the parameters to runtime wrappers
-  auto paramsCount = rewriter.create<mlir::LLVM::ConstantOp>(
-      loc, llvmInt64Type,
-      mlir::IntegerAttr::get(
-          llvmInt64Type,
-          static_cast<int64_t>(launchOp.getNumKernelOperands())));
-
-  // Get the function from the module. The name corresponds to the name of
-  // the kernel function.
-  auto kernelName = generateKernelNameConstant(
-      launchOp.getKernelModuleName().getValue(),
-      launchOp.getKernelName().getValue(), loc, rewriter);
-  auto function = moduleGetFunctionCallBuilder.create(
-      loc, rewriter, {module.getResult(), kernelName});
-  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
-  Value stream =
-      adaptor.getAsyncDependencies().empty()
-          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
-          : adaptor.getAsyncDependencies().front();
-  // Create array of pointers to kernel arguments.
-  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
-  auto nullpointer = rewriter.create<LLVM::ZeroOp>(loc, llvmPointerType);
-  Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize()
-                                      ? launchOp.getDynamicSharedMemorySize()
-                                      : zero;
-  launchKernelCallBuilder.create(
-      loc, rewriter,
-      {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
-       adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
-       adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
-       /*extra=*/nullpointer, paramsCount});
-
-  if (launchOp.getAsyncToken()) {
-    // Async launch: make dependent ops use the same stream.
+  rewriter.create<gpu::LaunchFuncOp>(
+      launchOp.getLoc(), launchOp.getKernelAttr(),
+      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
+                      adaptor.getGridSizeZ()},
+      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
+                      adaptor.getBlockSizeZ()},
+      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+  if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
-  } else {
-    // Synchronize with host and destroy stream. This must be the stream created
-    // above (with no other uses) because we check that the synchronous version
-    // does not have any async dependencies.
-    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
-    streamDestroyCallBuilder.create(loc, rewriter, stream);
+  else
     rewriter.eraseOp(launchOp);
-  }
-  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
-
   return success();
 }
 
@@ -1978,9 +1723,7 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
 
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
-                                               StringRef gpuBinaryAnnotation,
-                                               bool kernelBarePtrCallConv,
-                                               SymbolTable *cachedModuleTable) {
+                                               bool kernelBarePtrCallConv) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -2017,7 +1760,5 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
-  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
-      converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
-  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
+  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
 }

diff  --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index c0b05ef086033..6c5c1e09c0eb5 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -1,15 +1,8 @@
-// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=nvvm.cubin" -split-input-file | FileCheck %s
-// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=rocdl.hsaco" -split-input-file | FileCheck %s --check-prefix=ROCDL
+// RUN: mlir-opt %s --gpu-to-llvm -split-input-file | FileCheck %s
 
 module attributes {gpu.container_module} {
-
-  // CHECK: llvm.mlir.global internal constant @[[KERNEL_NAME:.*]]("kernel\00")
-  // CHECK: llvm.mlir.global internal constant @[[GLOBAL:.*]]("CUBIN")
-  // ROCDL: llvm.mlir.global internal constant @[[GLOBAL:.*]]("HSACO")
-
-  gpu.module @kernel_module attributes {
-      nvvm.cubin = "CUBIN", rocdl.hsaco = "HSACO"
-  } {
+  // CHECK: gpu.module
+  gpu.module @kernel_module [#nvvm.target] {
     llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
         %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
         %arg5: i64) attributes {gpu.kernel} {
@@ -18,9 +11,17 @@ module attributes {gpu.container_module} {
   }
 
   func.func @foo(%buffer: memref<?xf32>) {
+  // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+  // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
     %c8 = arith.constant 8 : index
     %c32 = arith.constant 32 : i32
     %c256 = arith.constant 256 : i32
+
+  // CHECK: gpu.launch_func @kernel_module::@kernel
+  // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
+  // CHECK: dynamic_shared_memory_size [[C256]]
+  // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
     gpu.launch_func @kernel_module::@kernel
         blocks in (%c8, %c8, %c8)
         threads in (%c8, %c8, %c8)
@@ -28,46 +29,13 @@ module attributes {gpu.container_module} {
         args(%c32 : i32, %buffer : memref<?xf32>)
     return
   }
-
-  // CHECK-DAG: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
-  // CHECK-DAG: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
-  // CHECK: [[ADDRESSOF:%.*]] = llvm.mlir.addressof @[[GLOBAL]]
-  // CHECK: [[BINARY:%.*]] = llvm.getelementptr [[ADDRESSOF]]{{\[}}0, 0]
-  // CHECK-SAME: -> !llvm.ptr
-  // CHECK: [[BINARYSIZE:%.*]] = llvm.mlir.constant
-  // CHECK: [[MODULE:%.*]] = llvm.call @mgpuModuleLoad([[BINARY]], [[BINARYSIZE]])
-  // CHECK: [[PARAMSCOUNT:%.*]] = llvm.mlir.constant
-  // CHECK: [[FUNC:%.*]] = llvm.call @mgpuModuleGetFunction([[MODULE]], {{.*}})
-
-  // CHECK: [[STREAM:%.*]] = llvm.call @mgpuStreamCreate
-
-  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32)
-  // CHECK: %[[MEMREF:.*]] = llvm.alloca %[[ONE]] x !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: [[NUM_PARAMS:%.*]] = llvm.mlir.constant(6 : i32) : i32
-  // CHECK-NEXT: [[PARAMS:%.*]] = llvm.alloca [[NUM_PARAMS]] x !llvm.ptr
-
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-  // CHECK: llvm.getelementptr %[[MEMREF]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct[[STRUCT_BODY:<.*>]]
-
-  // CHECK: [[EXTRA_PARAMS:%.*]] = llvm.mlir.zero : !llvm.ptr
-
-  // CHECK: llvm.call @mgpuLaunchKernel([[FUNC]], [[C8]], [[C8]], [[C8]],
-  // CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C256]], [[STREAM]],
-  // CHECK-SAME: [[PARAMS]], [[EXTRA_PARAMS]], [[PARAMSCOUNT]])
-  // CHECK: llvm.call @mgpuStreamSynchronize
-  // CHECK: llvm.call @mgpuStreamDestroy
-  // CHECK: llvm.call @mgpuModuleUnload
 }
 
+
 // -----
 
 module attributes {gpu.container_module} {
   // CHECK: gpu.module
-  // ROCDL: gpu.module
   gpu.module @kernel_module [#nvvm.target] {
     llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
         %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
@@ -80,15 +48,19 @@ module attributes {gpu.container_module} {
   // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
   // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
   // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
-    %c8 = arith.constant 8 : index
+  // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+    %c8 = arith.constant 8 : index    
     %c32 = arith.constant 32 : i32
     %c256 = arith.constant 256 : i32
+    %c2 = arith.constant 2 : index
 
   // CHECK: gpu.launch_func @kernel_module::@kernel
+  // CHECK: clusters in ([[C2]], [[C2]], [[C2]])
   // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
   // CHECK: dynamic_shared_memory_size [[C256]]
   // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
     gpu.launch_func @kernel_module::@kernel
+        clusters in (%c2, %c2, %c2)
         blocks in (%c8, %c8, %c8)
         threads in (%c8, %c8, %c8)
         dynamic_shared_memory_size %c256
@@ -97,18 +69,11 @@ module attributes {gpu.container_module} {
   }
 }
 
-
 // -----
 
 module attributes {gpu.container_module} {
-  // CHECK: gpu.module
-  gpu.module @kernel_module [#nvvm.target] {
-    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
-        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
-        %arg5: i64) attributes {gpu.kernel} {
-      llvm.return
-    }
-  }
+  // CHECK: gpu.binary
+  gpu.binary @kernel_module [#gpu.object<#rocdl.target, "blob">]
 
   func.func @foo(%buffer: memref<?xf32>) {
   // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64


        


More information about the Mlir-commits mailing list