[Mlir-commits] [mlir] [mlir][gpu] Remove old GPU serialization passes (PR #94998)

Fabian Mora llvmlistbot at llvm.org
Thu Jun 13 08:45:58 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/94998

>From 2fa5f7a5450b34ced823faf92386f2dfd68aa0bc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 10 Jun 2024 14:58:38 +0000
Subject: [PATCH 1/2] [mlir][gpu] Remove old GPU serialization passes

This patch removes the last vestiges of the old GPU serialization
pipeline. To compile GPU code, use target attributes instead.
---
 mlir/include/mlir/Conversion/Passes.td        |   2 +-
 .../mlir/Dialect/GPU/Transforms/Passes.h      |  63 ---
 .../mlir/Dialect/GPU/Transforms/Utils.h       |   3 -
 mlir/lib/Dialect/GPU/CMakeLists.txt           |  43 --
 .../GPU/Transforms/SerializeToBlob.cpp        | 153 ------
 .../GPU/Transforms/SerializeToHsaco.cpp       | 458 ------------------
 6 files changed, 1 insertion(+), 721 deletions(-)
 delete mode 100644 mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
 delete mode 100644 mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index db67d6a5ff128..71edd1b5f27c3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -478,7 +478,7 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
            /*default=*/"false",
              "Use bare pointers to pass memref arguments to kernels. "
              "The kernel must use the same setting for this option."
-           >
+          >
   ];
 
   let dependentDialects = [
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 8f7466a697d85..a20bae86ace28 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -91,75 +91,12 @@ namespace gpu {
 LogicalResult transformGpuModulesToBinaries(
     Operation *op, OffloadingLLVMTranslationAttrInterface handler = nullptr,
     const gpu::TargetOptions &options = {});
-
-/// Base pass class to serialize kernel functions through LLVM into
-/// user-specified IR and add the resulting blob as module attribute.
-class SerializeToBlobPass : public OperationPass<gpu::GPUModuleOp> {
-public:
-  SerializeToBlobPass(TypeID passID);
-  SerializeToBlobPass(const SerializeToBlobPass &other);
-
-  void runOnOperation() final;
-
-protected:
-  /// Hook allowing the application of optimizations before codegen
-  /// By default, does nothing
-  virtual LogicalResult optimizeLlvm(llvm::Module &llvmModule,
-                                     llvm::TargetMachine &targetMachine);
-
-  /// Translates the 'getOperation()' result to an LLVM module.
-  virtual std::unique_ptr<llvm::Module>
-  translateToLLVMIR(llvm::LLVMContext &llvmContext);
-
-private:
-  /// Creates the LLVM target machine to generate the ISA.
-  std::unique_ptr<llvm::TargetMachine> createTargetMachine();
-
-  /// Translates the module to ISA
-  std::optional<std::string> translateToISA(llvm::Module &llvmModule,
-                                            llvm::TargetMachine &targetMachine);
-
-  /// Serializes the target ISA to binary form.
-  virtual std::unique_ptr<std::vector<char>>
-  serializeISA(const std::string &isa) = 0;
-
-protected:
-  Option<std::string> triple{*this, "triple",
-                             ::llvm::cl::desc("Target triple")};
-  Option<std::string> chip{*this, "chip",
-                           ::llvm::cl::desc("Target architecture")};
-  Option<std::string> features{*this, "features",
-                               ::llvm::cl::desc("Target features")};
-  Option<int> optLevel{*this, "opt-level",
-                       llvm::cl::desc("Optimization level for compilation"),
-                       llvm::cl::init(2)};
-  Option<std::string> gpuBinaryAnnotation{
-      *this, "gpu-binary-annotation",
-      llvm::cl::desc("Annotation attribute string for GPU binary"),
-      llvm::cl::init(getDefaultGpuBinaryAnnotation())};
-  Option<bool> dumpPtx{*this, "dump-ptx",
-                       ::llvm::cl::desc("Dump generated PTX"),
-                       llvm::cl::init(false)};
-};
 } // namespace gpu
 
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
 
-/// Register pass to serialize GPU kernel functions to a HSAco binary
-/// annotation.
-LLVM_DEPRECATED("use Target attributes instead", "")
-void registerGpuSerializeToHsacoPass();
-
-/// Create an instance of the GPU kernel function to HSAco binary serialization
-/// pass.
-LLVM_DEPRECATED("use Target attributes instead", "")
-std::unique_ptr<Pass> createGpuSerializeToHsacoPass(StringRef triple,
-                                                    StringRef arch,
-                                                    StringRef features,
-                                                    int optLevel);
-
 /// Collect a set of patterns to decompose memrefs ops.
 void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
index f25c506fd638d..f8c018ef40bba 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
@@ -28,9 +28,6 @@ namespace gpu {
 class GPUFuncOp;
 class LaunchOp;
 
-/// Returns the default annotation name for GPU binary blobs.
-std::string getDefaultGpuBinaryAnnotation();
-
 /// Returns the matching vector combining kind.
 vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode);
 } // namespace gpu
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 61ab298ebfb98..1934744c47fc9 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -1,17 +1,3 @@
-if (MLIR_ENABLE_ROCM_CONVERSIONS)
-  set(AMDGPU_LIBS
-    IRReader
-    IPO
-    linker
-    MCParser
-    AMDGPUAsmParser
-    AMDGPUCodeGen
-    AMDGPUDesc
-    AMDGPUInfo
-    target
-  )
-endif()
-
 add_mlir_dialect_library(MLIRGPUDialect
   IR/GPUDialect.cpp
   IR/InferIntRangeInterfaceImpls.cpp
@@ -51,8 +37,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/NVVMAttachTarget.cpp
   Transforms/ParallelLoopMapper.cpp
   Transforms/ROCDLAttachTarget.cpp
-  Transforms/SerializeToBlob.cpp
-  Transforms/SerializeToHsaco.cpp
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
   Transforms/SubgroupReduceLowering.cpp
@@ -61,12 +45,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
 
-  LINK_COMPONENTS
-  Core
-  MC
-  Target
-  ${AMDGPU_LIBS}
-
   DEPENDS
   MLIRGPUPassIncGen
   MLIRParallelLoopMapperEnumsGen
@@ -76,15 +54,12 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRArithDialect
   MLIRAsyncDialect
   MLIRBufferizationDialect
-  MLIRBuiltinToLLVMIRTranslation
   MLIRDataLayoutInterfaces
   MLIRExecutionEngineUtils
   MLIRGPUDialect
   MLIRIR
   MLIRIndexDialect
   MLIRLLVMDialect
-  MLIRGPUToLLVMIRTranslation
-  MLIRLLVMToLLVMIRTranslation
   MLIRMemRefDialect
   MLIRNVVMTarget
   MLIRPass
@@ -99,21 +74,3 @@ add_mlir_dialect_library(MLIRGPUTransforms
 
 add_subdirectory(TransformOps)
 add_subdirectory(Pipelines)
-
-if(MLIR_ENABLE_ROCM_CONVERSIONS)
-  if (NOT ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD))
-    message(SEND_ERROR
-      "Building mlir with ROCm support requires the AMDGPU backend")
-  endif()
-
-  set(DEFAULT_ROCM_PATH "/opt/rocm" CACHE PATH "Fallback path to search for ROCm installs")
-  target_compile_definitions(obj.MLIRGPUTransforms
-    PRIVATE
-    __DEFAULT_ROCM_PATH__="${DEFAULT_ROCM_PATH}"
-  )
-
-  target_link_libraries(MLIRGPUTransforms
-    PRIVATE
-    MLIRROCDLToLLVMIRTranslation
-  )
-endif()
diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
deleted file mode 100644
index 1fdfe972a8b59..0000000000000
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
+++ /dev/null
@@ -1,153 +0,0 @@
-//===- SerializeToBlob.cpp - MLIR GPU lowering pass -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a base class for a pass to serialize a gpu module
-// into a binary blob that can be executed on a GPU. The binary blob is added
-// as a string attribute to the gpu module.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Export.h"
-#include "llvm/IR/LegacyPassManager.h"
-#include "llvm/MC/TargetRegistry.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Target/TargetMachine.h"
-
-#include <optional>
-#include <string>
-
-#define DEBUG_TYPE "serialize-to-blob"
-
-using namespace mlir;
-
-std::string gpu::getDefaultGpuBinaryAnnotation() { return "gpu.binary"; }
-
-gpu::SerializeToBlobPass::SerializeToBlobPass(TypeID passID)
-    : OperationPass<gpu::GPUModuleOp>(passID) {}
-
-gpu::SerializeToBlobPass::SerializeToBlobPass(const SerializeToBlobPass &other)
-    : OperationPass<gpu::GPUModuleOp>(other) {}
-
-std::optional<std::string>
-gpu::SerializeToBlobPass::translateToISA(llvm::Module &llvmModule,
-                                         llvm::TargetMachine &targetMachine) {
-  llvmModule.setDataLayout(targetMachine.createDataLayout());
-
-  if (failed(optimizeLlvm(llvmModule, targetMachine)))
-    return std::nullopt;
-
-  std::string targetISA;
-  llvm::raw_string_ostream stream(targetISA);
-
-  { // Drop pstream after this to prevent the ISA from being stuck buffering
-    llvm::buffer_ostream pstream(stream);
-    llvm::legacy::PassManager codegenPasses;
-
-    if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
-                                          llvm::CodeGenFileType::AssemblyFile))
-      return std::nullopt;
-
-    codegenPasses.run(llvmModule);
-  }
-  return stream.str();
-}
-
-void gpu::SerializeToBlobPass::runOnOperation() {
-  // Lower the module to an LLVM IR module using a separate context to enable
-  // multi-threaded processing.
-  llvm::LLVMContext llvmContext;
-  std::unique_ptr<llvm::Module> llvmModule = translateToLLVMIR(llvmContext);
-  if (!llvmModule)
-    return signalPassFailure();
-
-  // Lower the LLVM IR module to target ISA.
-  std::unique_ptr<llvm::TargetMachine> targetMachine = createTargetMachine();
-  if (!targetMachine)
-    return signalPassFailure();
-
-  std::optional<std::string> maybeTargetISA =
-      translateToISA(*llvmModule, *targetMachine);
-
-  if (!maybeTargetISA.has_value())
-    return signalPassFailure();
-
-  std::string targetISA = std::move(*maybeTargetISA);
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "ISA for module: " << getOperation().getNameAttr() << "\n";
-    llvm::dbgs() << targetISA << "\n";
-    llvm::dbgs().flush();
-  });
-
-  // Serialize the target ISA.
-  std::unique_ptr<std::vector<char>> blob = serializeISA(targetISA);
-  if (!blob)
-    return signalPassFailure();
-
-  // Add the blob as module attribute.
-  auto attr =
-      StringAttr::get(&getContext(), StringRef(blob->data(), blob->size()));
-  getOperation()->setAttr(gpuBinaryAnnotation, attr);
-}
-
-LogicalResult
-gpu::SerializeToBlobPass::optimizeLlvm(llvm::Module &llvmModule,
-                                       llvm::TargetMachine &targetMachine) {
-  int optLevel = this->optLevel.getValue();
-  if (optLevel < 0 || optLevel > 3)
-    return getOperation().emitError()
-           << "invalid optimization level " << optLevel;
-
-  targetMachine.setOptLevel(static_cast<llvm::CodeGenOptLevel>(optLevel));
-
-  auto transformer =
-      makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine);
-  auto error = transformer(&llvmModule);
-  if (error) {
-    InFlightDiagnostic mlirError = getOperation()->emitError();
-    llvm::handleAllErrors(
-        std::move(error), [&mlirError](const llvm::ErrorInfoBase &ei) {
-          mlirError << "could not optimize LLVM IR: " << ei.message();
-        });
-    return mlirError;
-  }
-  return success();
-}
-
-std::unique_ptr<llvm::TargetMachine>
-gpu::SerializeToBlobPass::createTargetMachine() {
-  Location loc = getOperation().getLoc();
-  std::string error;
-  const llvm::Target *target =
-      llvm::TargetRegistry::lookupTarget(triple, error);
-  if (!target) {
-    emitError(loc, Twine("failed to lookup target: ") + error);
-    return {};
-  }
-  llvm::TargetMachine *machine =
-      target->createTargetMachine(triple, chip, features, {}, {});
-  if (!machine) {
-    emitError(loc, "failed to create target machine");
-    return {};
-  }
-
-  return std::unique_ptr<llvm::TargetMachine>{machine};
-}
-
-std::unique_ptr<llvm::Module>
-gpu::SerializeToBlobPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
-  return translateModuleToLLVMIR(getOperation(), llvmContext,
-                                 "LLVMDialectModule");
-}
diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
deleted file mode 100644
index a4f19981eec38..0000000000000
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
+++ /dev/null
@@ -1,458 +0,0 @@
-//===- LowerGPUToHSACO.cpp - Convert GPU kernel to HSACO blob -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a pass that serializes a gpu module into HSAco blob and
-// adds that blob as a string attribute of the module.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Config/mlir-config.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-
-#if MLIR_ENABLE_ROCM_CONVERSIONS
-#include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/FileUtilities.h"
-#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Export.h"
-
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/GlobalVariable.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IRReader/IRReader.h"
-#include "llvm/Linker/Linker.h"
-
-#include "llvm/MC/MCAsmBackend.h"
-#include "llvm/MC/MCAsmInfo.h"
-#include "llvm/MC/MCCodeEmitter.h"
-#include "llvm/MC/MCContext.h"
-#include "llvm/MC/MCInstrInfo.h"
-#include "llvm/MC/MCObjectFileInfo.h"
-#include "llvm/MC/MCObjectWriter.h"
-#include "llvm/MC/MCParser/MCTargetAsmParser.h"
-#include "llvm/MC/MCRegisterInfo.h"
-#include "llvm/MC/MCStreamer.h"
-#include "llvm/MC/MCSubtargetInfo.h"
-#include "llvm/MC/TargetRegistry.h"
-
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/FileSystem.h"
-#include "llvm/Support/FileUtilities.h"
-#include "llvm/Support/Path.h"
-#include "llvm/Support/Program.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Support/Threading.h"
-#include "llvm/Support/WithColor.h"
-
-#include "llvm/Target/TargetMachine.h"
-#include "llvm/Target/TargetOptions.h"
-
-#include "llvm/Transforms/IPO/Internalize.h"
-
-#include <optional>
-
-using namespace mlir;
-
-namespace {
-class SerializeToHsacoPass
-    : public PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass> {
-  static llvm::once_flag initializeBackendOnce;
-
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SerializeToHsacoPass)
-
-  SerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features,
-                       int optLevel);
-  SerializeToHsacoPass(const SerializeToHsacoPass &other);
-  StringRef getArgument() const override { return "gpu-to-hsaco"; }
-  StringRef getDescription() const override {
-    return "Lower GPU kernel function to HSACO binary annotations";
-  }
-
-protected:
-  Option<std::string> rocmPath{*this, "rocm-path",
-                               llvm::cl::desc("Path to ROCm install")};
-
-  // Overload to allow linking in device libs
-  std::unique_ptr<llvm::Module>
-  translateToLLVMIR(llvm::LLVMContext &llvmContext) override;
-
-private:
-  // Loads LLVM bitcode libraries
-  std::optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
-  loadLibraries(SmallVectorImpl<char> &path,
-                SmallVectorImpl<StringRef> &libraries,
-                llvm::LLVMContext &context);
-
-  // Serializes ROCDL to HSACO.
-  std::unique_ptr<std::vector<char>>
-  serializeISA(const std::string &isa) override;
-
-  LogicalResult assembleIsa(const std::string &isa,
-                            SmallVectorImpl<char> &result);
-  std::unique_ptr<std::vector<char>> createHsaco(ArrayRef<char> isaBinary);
-
-  std::string getRocmPath();
-};
-} // namespace
-
-SerializeToHsacoPass::SerializeToHsacoPass(const SerializeToHsacoPass &other)
-    : PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass>(other) {}
-
-/// Get a user-specified path to ROCm
-// Tries, in order, the --rocm-path option, the ROCM_PATH environment variable
-// and a compile-time default
-std::string SerializeToHsacoPass::getRocmPath() {
-  if (rocmPath.getNumOccurrences() > 0)
-    return rocmPath.getValue();
-
-  return __DEFAULT_ROCM_PATH__;
-}
-
-// Sets the 'option' to 'value' unless it already has a value.
-static void maybeSetOption(Pass::Option<std::string> &option,
-                           function_ref<std::string()> getValue) {
-  if (!option.hasValue())
-    option = getValue();
-}
-
-llvm::once_flag SerializeToHsacoPass::initializeBackendOnce;
-
-SerializeToHsacoPass::SerializeToHsacoPass(StringRef triple, StringRef arch,
-                                           StringRef features, int optLevel) {
-  // No matter how this pass is constructed, ensure that the AMDGPU backend
-  // is initialized exactly once.
-  llvm::call_once(initializeBackendOnce, []() {
-    // Initialize LLVM AMDGPU backend.
-    LLVMInitializeAMDGPUAsmParser();
-    LLVMInitializeAMDGPUAsmPrinter();
-    LLVMInitializeAMDGPUTarget();
-    LLVMInitializeAMDGPUTargetInfo();
-    LLVMInitializeAMDGPUTargetMC();
-  });
-  maybeSetOption(this->triple, [&triple] { return triple.str(); });
-  maybeSetOption(this->chip, [&arch] { return arch.str(); });
-  maybeSetOption(this->features, [&features] { return features.str(); });
-  if (this->optLevel.getNumOccurrences() == 0)
-    this->optLevel.setValue(optLevel);
-}
-
-std::optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
-SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
-                                    SmallVectorImpl<StringRef> &libraries,
-                                    llvm::LLVMContext &context) {
-  SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
-  size_t dirLength = path.size();
-
-  if (!llvm::sys::fs::is_directory(path)) {
-    getOperation().emitRemark() << "Bitcode path: " << path
-                                << " does not exist or is not a directory\n";
-    return std::nullopt;
-  }
-
-  for (const StringRef file : libraries) {
-    llvm::SMDiagnostic error;
-    llvm::sys::path::append(path, file);
-    llvm::StringRef pathRef(path.data(), path.size());
-    std::unique_ptr<llvm::Module> library =
-        llvm::getLazyIRFileModule(pathRef, error, context);
-    path.truncate(dirLength);
-    if (!library) {
-      getOperation().emitError() << "Failed to load library " << file
-                                 << " from " << path << error.getMessage();
-      return std::nullopt;
-    }
-    // Some ROCM builds don't strip this like they should
-    if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version"))
-      library->eraseNamedMetadata(openclVersion);
-    // Stop spamming us with clang version numbers
-    if (auto *ident = library->getNamedMetadata("llvm.ident"))
-      library->eraseNamedMetadata(ident);
-    ret.push_back(std::move(library));
-  }
-
-  return std::move(ret);
-}
-
-std::unique_ptr<llvm::Module>
-SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
-  // MLIR -> LLVM translation
-  std::unique_ptr<llvm::Module> ret =
-      gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
-
-  if (!ret) {
-    getOperation().emitOpError("Module lowering failed");
-    return ret;
-  }
-  // Walk the LLVM module in order to determine if we need to link in device
-  // libs
-  bool needOpenCl = false;
-  bool needOckl = false;
-  bool needOcml = false;
-  for (llvm::Function &f : ret->functions()) {
-    if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
-      StringRef funcName = f.getName();
-      if ("printf" == funcName)
-        needOpenCl = true;
-      if (funcName.starts_with("__ockl_"))
-        needOckl = true;
-      if (funcName.starts_with("__ocml_"))
-        needOcml = true;
-    }
-  }
-
-  if (needOpenCl)
-    needOcml = needOckl = true;
-
-  // No libraries needed (the typical case)
-  if (!(needOpenCl || needOcml || needOckl))
-    return ret;
-
-  // Define one of the control constants the ROCm device libraries expect to be
-  // present These constants can either be defined in the module or can be
-  // imported by linking in bitcode that defines the constant. To simplify our
-  // logic, we define the constants into the module we are compiling
-  auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
-                                             uint32_t bitwidth) {
-    using llvm::GlobalVariable;
-    if (module.getNamedGlobal(name)) {
-      return;
-    }
-    llvm::IntegerType *type =
-        llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
-    auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
-    auto *constant = new GlobalVariable(
-        module, type,
-        /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
-        initializer, name,
-        /*before=*/nullptr,
-        /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
-        /*addressSpace=*/4);
-    constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
-    constant->setVisibility(
-        GlobalVariable::VisibilityTypes::ProtectedVisibility);
-    constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
-  };
-
-  // Set up control variables in the module instead of linking in tiny bitcode
-  if (needOcml) {
-    // TODO(kdrewnia): Enable math optimizations once we have support for
-    // `-ffast-math`-like options
-    addControlConstant("__oclc_finite_only_opt", 0, 8);
-    addControlConstant("__oclc_daz_opt", 0, 8);
-    addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
-    addControlConstant("__oclc_unsafe_math_opt", 0, 8);
-  }
-  if (needOcml || needOckl) {
-    addControlConstant("__oclc_wavefrontsize64", 1, 8);
-    StringRef chipSet = this->chip.getValue();
-    if (chipSet.starts_with("gfx"))
-      chipSet = chipSet.substr(3);
-    uint32_t minor =
-        llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue();
-    uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10)
-                         .getZExtValue();
-    uint32_t isaNumber = minor + 1000 * major;
-    addControlConstant("__oclc_ISA_version", isaNumber, 32);
-
-    // This constant must always match the default code object ABI version
-    // of the AMDGPU backend.
-    addControlConstant("__oclc_ABI_version", 500, 32);
-  }
-
-  // Determine libraries we need to link - order matters due to dependencies
-  llvm::SmallVector<StringRef, 4> libraries;
-  if (needOpenCl)
-    libraries.push_back("opencl.bc");
-  if (needOcml)
-    libraries.push_back("ocml.bc");
-  if (needOckl)
-    libraries.push_back("ockl.bc");
-
-  std::optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
-  std::string theRocmPath = getRocmPath();
-  llvm::SmallString<32> bitcodePath(theRocmPath);
-  llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
-  mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
-
-  if (!mbModules) {
-    getOperation()
-            .emitWarning("Could not load required device libraries")
-            .attachNote()
-        << "This will probably cause link-time or run-time failures";
-    return ret; // We can still abort here
-  }
-
-  llvm::Linker linker(*ret);
-  for (std::unique_ptr<llvm::Module> &libModule : *mbModules) {
-    // This bitcode linking code is substantially similar to what is used in
-    // hip-clang It imports the library functions into the module, allowing LLVM
-    // optimization passes (which must run after linking) to optimize across the
-    // libraries and the module's code. We also only import symbols if they are
-    // referenced by the module or a previous library since there will be no
-    // other source of references to those symbols in this compilation and since
-    // we don't want to bloat the resulting code object.
-    bool err = linker.linkInModule(
-        std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
-        [](llvm::Module &m, const StringSet<> &gvs) {
-          llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
-            return !gv.hasName() || (gvs.count(gv.getName()) == 0);
-          });
-        });
-    // True is linker failure
-    if (err) {
-      getOperation().emitError(
-          "Unrecoverable failure during device library linking.");
-      // We have no guaranties about the state of `ret`, so bail
-      return nullptr;
-    }
-  }
-
-  return ret;
-}
-
-LogicalResult SerializeToHsacoPass::assembleIsa(const std::string &isa,
-                                                SmallVectorImpl<char> &result) {
-  auto loc = getOperation().getLoc();
-
-  llvm::raw_svector_ostream os(result);
-
-  llvm::Triple triple(llvm::Triple::normalize(this->triple));
-  std::string error;
-  const llvm::Target *target =
-      llvm::TargetRegistry::lookupTarget(triple.normalize(), error);
-  if (!target)
-    return emitError(loc, Twine("failed to lookup target: ") + error);
-
-  llvm::SourceMgr srcMgr;
-  srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa), SMLoc());
-
-  const llvm::MCTargetOptions mcOptions;
-  std::unique_ptr<llvm::MCRegisterInfo> mri(
-      target->createMCRegInfo(this->triple));
-  std::unique_ptr<llvm::MCAsmInfo> mai(
-      target->createMCAsmInfo(*mri, this->triple, mcOptions));
-  std::unique_ptr<llvm::MCSubtargetInfo> sti(
-      target->createMCSubtargetInfo(this->triple, this->chip, this->features));
-
-  llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr,
-                      &mcOptions);
-  std::unique_ptr<llvm::MCObjectFileInfo> mofi(target->createMCObjectFileInfo(
-      ctx, /*PIC=*/false, /*LargeCodeModel=*/false));
-  ctx.setObjectFileInfo(mofi.get());
-
-  SmallString<128> cwd;
-  if (!llvm::sys::fs::current_path(cwd))
-    ctx.setCompilationDir(cwd);
-
-  std::unique_ptr<llvm::MCStreamer> mcStreamer;
-  std::unique_ptr<llvm::MCInstrInfo> mcii(target->createMCInstrInfo());
-
-  llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, ctx);
-  llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions);
-  mcStreamer.reset(target->createMCObjectStreamer(
-      triple, ctx, std::unique_ptr<llvm::MCAsmBackend>(mab),
-      mab->createObjectWriter(os), std::unique_ptr<llvm::MCCodeEmitter>(ce),
-      *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
-      /*DWARFMustBeAtTheEnd*/ false));
-
-  std::unique_ptr<llvm::MCAsmParser> parser(
-      createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
-  std::unique_ptr<llvm::MCTargetAsmParser> tap(
-      target->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
-
-  if (!tap)
-    return emitError(loc, "assembler initialization error");
-
-  parser->setTargetParser(*tap);
-  parser->Run(false);
-
-  return success();
-}
-
-std::unique_ptr<std::vector<char>>
-SerializeToHsacoPass::createHsaco(ArrayRef<char> isaBinary) {
-  auto loc = getOperation().getLoc();
-
-  // Save the ISA binary to a temp file.
-  int tempIsaBinaryFd = -1;
-  SmallString<128> tempIsaBinaryFilename;
-  if (llvm::sys::fs::createTemporaryFile("kernel", "o", tempIsaBinaryFd,
-                                         tempIsaBinaryFilename)) {
-    emitError(loc, "temporary file for ISA binary creation error");
-    return {};
-  }
-  llvm::FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
-  llvm::raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, true);
-  tempIsaBinaryOs << StringRef(isaBinary.data(), isaBinary.size());
-  tempIsaBinaryOs.close();
-
-  // Create a temp file for HSA code object.
-  SmallString<128> tempHsacoFilename;
-  if (llvm::sys::fs::createTemporaryFile("kernel", "hsaco",
-                                         tempHsacoFilename)) {
-    emitError(loc, "temporary file for HSA code object creation error");
-    return {};
-  }
-  llvm::FileRemover cleanupHsaco(tempHsacoFilename);
-
-  std::string theRocmPath = getRocmPath();
-  llvm::SmallString<32> lldPath(theRocmPath);
-  llvm::sys::path::append(lldPath, "llvm", "bin", "ld.lld");
-  int lldResult = llvm::sys::ExecuteAndWait(
-      lldPath,
-      {"ld.lld", "-shared", tempIsaBinaryFilename, "-o", tempHsacoFilename});
-  if (lldResult != 0) {
-    emitError(loc, "lld invocation error");
-    return {};
-  }
-
-  // Load the HSA code object.
-  auto hsacoFile =
-      llvm::MemoryBuffer::getFile(tempHsacoFilename, /*IsText=*/false);
-  if (!hsacoFile) {
-    emitError(loc, "read HSA code object from temp file error");
-    return {};
-  }
-
-  StringRef buffer = (*hsacoFile)->getBuffer();
-  return std::make_unique<std::vector<char>>(buffer.begin(), buffer.end());
-}
-
-std::unique_ptr<std::vector<char>>
-SerializeToHsacoPass::serializeISA(const std::string &isa) {
-  SmallVector<char, 0> isaBinary;
-  if (failed(assembleIsa(isa, isaBinary)))
-    return {};
-  return createHsaco(isaBinary);
-}
-
-// Register pass to serialize GPU kernel functions to a HSACO binary annotation.
-void mlir::registerGpuSerializeToHsacoPass() {
-  PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO([] {
-    return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "", "",
-                                                  2);
-  });
-}
-
-/// Create an instance of the GPU kernel function to HSAco binary serialization
-/// pass.
-std::unique_ptr<Pass> mlir::createGpuSerializeToHsacoPass(StringRef triple,
-                                                          StringRef arch,
-                                                          StringRef features,
-                                                          int optLevel) {
-  return std::make_unique<SerializeToHsacoPass>(triple, arch, features,
-                                                optLevel);
-}
-
-#else  // MLIR_ENABLE_ROCM_CONVERSIONS
-void mlir::registerGpuSerializeToHsacoPass() {}
-#endif // MLIR_ENABLE_ROCM_CONVERSIONS

>From 8947c02c7a9e7aa0ab65794196813c74f45b8420 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 13 Jun 2024 15:45:17 +0000
Subject: [PATCH 2/2] update docs

---
 mlir/docs/Dialects/GPU.md | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/docs/Dialects/GPU.md b/mlir/docs/Dialects/GPU.md
index 8a3acc33600a4..8a0f117549c03 100644
--- a/mlir/docs/Dialects/GPU.md
+++ b/mlir/docs/Dialects/GPU.md
@@ -37,9 +37,6 @@ complex lifetime analysis following the principles of MLIR that promote
 structure and representing analysis results in the IR.
 
 ## GPU Compilation
-### Deprecation notice
-The `--gpu-to-(cubin|hsaco)` passes will be deprecated in a future release.
-
 ### Compilation overview
 The compilation process in the GPU dialect has two main stages: GPU module
 serialization and offloading operations translation. Together these stages can



More information about the Mlir-commits mailing list