[Mlir-commits] [mlir] a6f53af - [MLIR][GPU] Link in device libraries during HSA compilation if needed

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Nov 19 14:29:43 PST 2021


Author: Krzysztof Drewniak
Date: 2021-11-19T22:29:37Z
New Revision: a6f53afbcb4d995139064276b5ad971ad7ced5e2

URL: https://github.com/llvm/llvm-project/commit/a6f53afbcb4d995139064276b5ad971ad7ced5e2
DIFF: https://github.com/llvm/llvm-project/commit/a6f53afbcb4d995139064276b5ad971ad7ced5e2.diff

LOG: [MLIR][GPU] Link in device libraries during HSA compilation if needed

To perform some operations, such as sin() or printf(), code compiled
for AMD GPUs must be linked to a series of device libraries. This
commit adds support for linking in these libraries.

However, since these device libraries are delivered as LLVM bitcode,
raising the possibility of version incompatibilities, this commit only
links in libraries when the functions from those libraries are called
by the code being compiled.

This code also sets the math flags to their most conservative values,
as MLIR doesn't have a `-ffast-math` equivalent.

Depends on D114114

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D114117

Added: 
    

Modified: 
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 285bc0f82a278..d7e011bf44115 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -8,6 +8,8 @@ endif()
 
 if (MLIR_ENABLE_ROCM_CONVERSIONS)
   set(AMDGPU_LIBS
+    IRReader
+    linker
     MCParser
     AMDGPUAsmParser
     AMDGPUCodeGen

diff  --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
index 023ac93f7a682..e8ad4ecd2593a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
@@ -21,6 +21,12 @@
 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
 
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Linker/Linker.h"
+
 #include "llvm/MC/MCAsmBackend.h"
 #include "llvm/MC/MCAsmInfo.h"
 #include "llvm/MC/MCCodeEmitter.h"
@@ -42,6 +48,8 @@
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
 
+#include "llvm/Transforms/IPO/Internalize.h"
+
 #include "lld/Common/Driver.h"
 
 #include <mutex>
@@ -69,6 +77,10 @@ class SerializeToHsacoPass
   Option<std::string> rocmPath{*this, "rocm-path",
                                llvm::cl::desc("Path to ROCm install")};
 
+  // Override to allow linking in device libs
+  std::unique_ptr<llvm::Module>
+  translateToLLVMIR(llvm::LLVMContext &llvmContext) override;
+
   /// Adds LLVM optimization passes
   LogicalResult optimizeLlvm(llvm::Module &llvmModule,
                              llvm::TargetMachine &targetMachine) override;
@@ -76,6 +88,12 @@ class SerializeToHsacoPass
 private:
   void getDependentDialects(DialectRegistry &registry) const override;
 
+  // Loads LLVM bitcode libraries
+  Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
+  loadLibraries(SmallVectorImpl<char> &path,
+                SmallVectorImpl<StringRef> &libraries,
+                llvm::LLVMContext &context);
+
   // Serializes ROCDL to HSACO.
   std::unique_ptr<std::vector<char>>
   serializeISA(const std::string &isa) override;
@@ -123,6 +141,174 @@ void SerializeToHsacoPass::getDependentDialects(
   gpu::SerializeToBlobPass::getDependentDialects(registry);
 }
 
+Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
+SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
+                                    SmallVectorImpl<StringRef> &libraries,
+                                    llvm::LLVMContext &context) {
+  SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
+  size_t dirLength = path.size();
+
+  if (!llvm::sys::fs::is_directory(path)) {
+    getOperation().emitRemark() << "Bitcode path: " << path
+                                << " does not exist or is not a directory\n";
+    return llvm::None;
+  }
+
+  for (const StringRef file : libraries) {
+    llvm::SMDiagnostic error;
+    llvm::sys::path::append(path, file);
+    llvm::StringRef pathRef(path.data(), path.size());
+    std::unique_ptr<llvm::Module> library =
+        llvm::getLazyIRFileModule(pathRef, error, context);
+    path.set_size(dirLength);
+    if (!library) {
+      getOperation().emitError() << "Failed to load library " << file
+                                 << " from " << path << error.getMessage();
+      return llvm::None;
+    }
+    // Some ROCM builds don't strip this like they should
+    if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version"))
+      library->eraseNamedMetadata(openclVersion);
+    // Stop spamming us with clang version numbers
+    if (auto *ident = library->getNamedMetadata("llvm.ident"))
+      library->eraseNamedMetadata(ident);
+    ret.push_back(std::move(library));
+  }
+
+  return ret;
+}
+
+std::unique_ptr<llvm::Module>
+SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
+  // MLIR -> LLVM translation
+  std::unique_ptr<llvm::Module> ret =
+      gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
+
+  if (!ret) {
+    getOperation().emitOpError("Module lowering failed");
+    return ret;
+  }
+  // Walk the LLVM module in order to determine if we need to link in device
+  // libs
+  bool needOpenCl = false;
+  bool needOckl = false;
+  bool needOcml = false;
+  for (llvm::Function &f : ret->functions()) {
+    if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
+      StringRef funcName = f.getName();
+      if ("printf" == funcName)
+        needOpenCl = true;
+      if (funcName.startswith("__ockl_"))
+        needOckl = true;
+      if (funcName.startswith("__ocml_"))
+        needOcml = true;
+    }
+  }
+
+  if (needOpenCl)
+    needOcml = needOckl = true;
+
+  // No libraries needed (the typical case)
+  if (!(needOpenCl || needOcml || needOckl))
+    return ret;
+
+  // Define one of the control constants the ROCm device libraries expect to be
+  // present. These constants can either be defined in the module or can be
+  // imported by linking in bitcode that defines the constant. To simplify our
+  // logic, we define the constants in the module we are compiling.
+  auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
+                                             uint32_t bitwidth) {
+    using llvm::GlobalVariable;
+    if (module.getNamedGlobal(name)) {
+      return;
+    }
+    llvm::IntegerType *type =
+        llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
+    auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
+    auto *constant = new GlobalVariable(
+        module, type,
+        /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
+        initializer, name,
+        /*before=*/nullptr,
+        /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
+        /*addressSpace=*/4);
+    constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
+    constant->setVisibility(
+        GlobalVariable::VisibilityTypes::ProtectedVisibility);
+    constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
+  };
+
+  if (needOcml) {
+    // TODO(kdrewnia): Enable math optimizations once we have support for
+    // `-ffast-math`-like options
+    addControlConstant("__oclc_finite_only_opt", 0, 8);
+    addControlConstant("__oclc_daz_opt", 0, 8);
+    addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
+    addControlConstant("__oclc_unsafe_math_opt", 0, 8);
+  }
+  if (needOcml || needOckl) {
+    addControlConstant("__oclc_wavefrontsize64", 1, 8);
+    StringRef chipSet = this->chip.getValue();
+    if (chipSet.startswith("gfx"))
+      chipSet = chipSet.substr(3);
+    uint32_t minor =
+        llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue();
+    uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10)
+                         .getZExtValue();
+    uint32_t isaNumber = minor + 1000 * major;
+    addControlConstant("__oclc_ISA_version", isaNumber, 32);
+  }
+
+  // Determine libraries we need to link - order matters due to dependencies
+  llvm::SmallVector<StringRef, 4> libraries;
+  if (needOpenCl)
+    libraries.push_back("opencl.bc");
+  if (needOcml)
+    libraries.push_back("ocml.bc");
+  if (needOckl)
+    libraries.push_back("ockl.bc");
+
+  Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
+  std::string theRocmPath = getRocmPath();
+  llvm::SmallString<32> bitcodePath(std::move(theRocmPath));
+  llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
+  mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
+
+  if (!mbModules) {
+    getOperation()
+            .emitWarning("Could not load required device libraries")
+            .attachNote()
+        << "This will probably cause link-time or run-time failures";
+    return ret; // We can still abort here
+  }
+
+  llvm::Linker linker(*ret);
+  for (std::unique_ptr<llvm::Module> &libModule : mbModules.getValue()) {
+    // This bitcode linking code is substantially similar to what is used in
+    // hip-clang. It imports the library functions into the module, allowing
+    // LLVM optimization passes (which must run after linking) to optimize
+    // across the libraries and the module's code. We also only import symbols
+    // if they are referenced by the module or a previous library since there
+    // will be no other source of references to those symbols in this
+    // compilation and since we don't want to bloat the resulting code object.
+    bool err = linker.linkInModule(
+        std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
+        [](llvm::Module &m, const StringSet<> &gvs) {
+          llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
+            return !gv.hasName() || (gvs.count(gv.getName()) == 0);
+          });
+        });
+    // linkInModule returns true on failure
+    if (err) {
+      getOperation().emitError(
+          "Unrecoverable failure during device library linking.");
+      // We have no guarantees about the state of `ret`, so bail
+      return nullptr;
+    }
+  }
+  return ret;
+}
+
 LogicalResult
 SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule,
                                    llvm::TargetMachine &targetMachine) {


        


More information about the Mlir-commits mailing list