[Mlir-commits] [mlir] a6f53af - [MLIR][GPU] Link in device libraries during HSA compilation if needed
Krzysztof Drewniak
llvmlistbot at llvm.org
Fri Nov 19 14:29:43 PST 2021
Author: Krzysztof Drewniak
Date: 2021-11-19T22:29:37Z
New Revision: a6f53afbcb4d995139064276b5ad971ad7ced5e2
URL: https://github.com/llvm/llvm-project/commit/a6f53afbcb4d995139064276b5ad971ad7ced5e2
DIFF: https://github.com/llvm/llvm-project/commit/a6f53afbcb4d995139064276b5ad971ad7ced5e2.diff
LOG: [MLIR][GPU] Link in device libraries during HSA compilation if needed
To perform some operations, such as sin() or printf(), code compiled
for AMD GPUs must be linked to a series of device libraries. This
commit adds support for linking in these libraries.
However, since these device libraries are delivered as LLVM bitcode,
raising the possibility of version incompatibilities, this commit only
links in libraries when the functions from those libraries are called
by the code being compiled.
This code also sets the math flags to their most conservative values,
as MLIR doesn't have a `-ffast-math` equivalent.
Depends on D114114
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D114117
Added:
Modified:
mlir/lib/Dialect/GPU/CMakeLists.txt
mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 285bc0f82a278..d7e011bf44115 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -8,6 +8,8 @@ endif()
if (MLIR_ENABLE_ROCM_CONVERSIONS)
set(AMDGPU_LIBS
+ IRReader
+ linker
MCParser
AMDGPUAsmParser
AMDGPUCodeGen
diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
index 023ac93f7a682..e8ad4ecd2593a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
@@ -21,6 +21,12 @@
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Linker/Linker.h"
+
#include "llvm/MC/MCAsmBackend.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCCodeEmitter.h"
@@ -42,6 +48,8 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
+#include "llvm/Transforms/IPO/Internalize.h"
+
#include "lld/Common/Driver.h"
#include <mutex>
@@ -69,6 +77,10 @@ class SerializeToHsacoPass
Option<std::string> rocmPath{*this, "rocm-path",
llvm::cl::desc("Path to ROCm install")};
+ // Override to allow linking in device libs
+ std::unique_ptr<llvm::Module>
+ translateToLLVMIR(llvm::LLVMContext &llvmContext) override;
+
/// Adds LLVM optimization passes
LogicalResult optimizeLlvm(llvm::Module &llvmModule,
llvm::TargetMachine &targetMachine) override;
@@ -76,6 +88,12 @@ class SerializeToHsacoPass
private:
void getDependentDialects(DialectRegistry &registry) const override;
+ // Loads LLVM bitcode libraries
+ Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
+ loadLibraries(SmallVectorImpl<char> &path,
+ SmallVectorImpl<StringRef> &libraries,
+ llvm::LLVMContext &context);
+
// Serializes ROCDL to HSACO.
std::unique_ptr<std::vector<char>>
serializeISA(const std::string &isa) override;
@@ -123,6 +141,174 @@ void SerializeToHsacoPass::getDependentDialects(
gpu::SerializeToBlobPass::getDependentDialects(registry);
}
+/// Loads the requested device libraries from a bitcode directory using
+/// llvm::getLazyIRFileModule.
+///
+/// `path` names the directory holding the bitcode files. It is used as scratch
+/// space to build each file name and is truncated back to its original length
+/// after every load attempt. `libraries` lists the file names to load, in
+/// order, and `context` is the LLVM context the modules are created in.
+///
+/// Returns the loaded modules in the same order as `libraries`, or llvm::None
+/// (after emitting a diagnostic on the operation) if `path` is not a directory
+/// or any library fails to load.
+Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
+SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
+                                    SmallVectorImpl<StringRef> &libraries,
+                                    llvm::LLVMContext &context) {
+  SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
+  size_t dirLength = path.size();
+
+  if (!llvm::sys::fs::is_directory(path)) {
+    getOperation().emitRemark() << "Bitcode path: " << path
+                                << " does not exist or is not a directory\n";
+    return llvm::None;
+  }
+
+  for (const StringRef file : libraries) {
+    llvm::SMDiagnostic error;
+    llvm::sys::path::append(path, file);
+    llvm::StringRef pathRef(path.data(), path.size());
+    std::unique_ptr<llvm::Module> library =
+        llvm::getLazyIRFileModule(pathRef, error, context);
+    // Drop the file name again so `path` is the bare directory for the next
+    // iteration and for the diagnostic below.
+    path.set_size(dirLength);
+    if (!library) {
+      // Separate the directory from the parser's message; without the ": "
+      // the two run together in the emitted diagnostic.
+      getOperation().emitError()
+          << "Failed to load library " << file << " from " << path << ": "
+          << error.getMessage();
+      return llvm::None;
+    }
+    // Some ROCM builds don't strip this like they should
+    if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version"))
+      library->eraseNamedMetadata(openclVersion);
+    // Stop spamming us with clang version numbers
+    if (auto *ident = library->getNamedMetadata("llvm.ident"))
+      library->eraseNamedMetadata(ident);
+    ret.push_back(std::move(library));
+  }
+
+  return ret;
+}
+
+
+/// Translates the GPU module to LLVM IR and, when the resulting module
+/// declares functions provided by the ROCm device libraries, defines the
+/// control constants those libraries expect and links the needed bitcode in.
+///
+/// Returns the translated (and possibly linked) module. Returns nullptr if the
+/// base translation fails, if the chipset name cannot be parsed into an ISA
+/// version, or if linking a device library fails. If the libraries cannot be
+/// loaded at all, a warning is emitted and the unlinked module is returned.
+std::unique_ptr<llvm::Module>
+SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
+  // MLIR -> LLVM translation
+  std::unique_ptr<llvm::Module> ret =
+      gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
+
+  if (!ret) {
+    getOperation().emitOpError("Module lowering failed");
+    return ret;
+  }
+  // Walk the LLVM module in order to determine if we need to link in device
+  // libs
+  bool needOpenCl = false;
+  bool needOckl = false;
+  bool needOcml = false;
+  for (llvm::Function &f : ret->functions()) {
+    if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
+      StringRef funcName = f.getName();
+      if ("printf" == funcName)
+        needOpenCl = true;
+      if (funcName.startswith("__ockl_"))
+        needOckl = true;
+      if (funcName.startswith("__ocml_"))
+        needOcml = true;
+    }
+  }
+
+  // Requesting the OpenCL library also requests OCML and OCKL.
+  if (needOpenCl)
+    needOcml = needOckl = true;
+
+  // No libraries needed (the typical case)
+  if (!(needOpenCl || needOcml || needOckl))
+    return ret;
+
+  // Define one of the control constants the ROCm device libraries expect to be
+  // present. These constants can either be defined in the module or can be
+  // imported by linking in bitcode that defines the constant. To simplify our
+  // logic, we define the constants into the module we are compiling.
+  // A constant that already exists in the module is left untouched.
+  auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
+                                             uint32_t bitwidth) {
+    using llvm::GlobalVariable;
+    if (module.getNamedGlobal(name)) {
+      return;
+    }
+    llvm::IntegerType *type =
+        llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
+    auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
+    auto *constant = new GlobalVariable(
+        module, type,
+        /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
+        initializer, name,
+        /*before=*/nullptr,
+        /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
+        /*addressSpace=*/4);
+    constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
+    constant->setVisibility(
+        GlobalVariable::VisibilityTypes::ProtectedVisibility);
+    constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
+  };
+
+  if (needOcml) {
+    // TODO(kdrewnia): Enable math optimizations once we have support for
+    // `-ffast-math`-like options
+    addControlConstant("__oclc_finite_only_opt", 0, 8);
+    addControlConstant("__oclc_daz_opt", 0, 8);
+    addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
+    addControlConstant("__oclc_unsafe_math_opt", 0, 8);
+  }
+  if (needOcml || needOckl) {
+    addControlConstant("__oclc_wavefrontsize64", 1, 8);
+    StringRef fullChip = this->chip.getValue();
+    StringRef chipSet = fullChip;
+    if (chipSet.startswith("gfx"))
+      chipSet = chipSet.substr(3);
+    // The last two characters are read as a hexadecimal number and everything
+    // before them as a decimal number. Use the checked parser (getAsInteger
+    // returns true on failure) so that a short or malformed chip name produces
+    // a diagnostic instead of tripping an assertion while constructing an
+    // APInt from an empty or non-numeric string.
+    uint32_t minor = 0;
+    uint32_t major = 0;
+    if (chipSet.size() < 3 ||
+        chipSet.substr(chipSet.size() - 2).getAsInteger(16, minor) ||
+        chipSet.substr(0, chipSet.size() - 2).getAsInteger(10, major)) {
+      getOperation().emitError()
+          << "Unable to determine the ISA version from chipset '" << fullChip
+          << "'";
+      return nullptr;
+    }
+    uint32_t isaNumber = minor + 1000 * major;
+    addControlConstant("__oclc_ISA_version", isaNumber, 32);
+  }
+
+  // Determine libraries we need to link - order matters due to dependencies
+  llvm::SmallVector<StringRef, 4> libraries;
+  if (needOpenCl)
+    libraries.push_back("opencl.bc");
+  if (needOcml)
+    libraries.push_back("ocml.bc");
+  if (needOckl)
+    libraries.push_back("ockl.bc");
+
+  Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
+  std::string theRocmPath = getRocmPath();
+  llvm::SmallString<32> bitcodePath(std::move(theRocmPath));
+  llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
+  mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
+
+  if (!mbModules) {
+    getOperation()
+        .emitWarning("Could not load required device libraries")
+        .attachNote()
+        << "This will probably cause link-time or run-time failures";
+    // Nothing has been linked yet, so the module is still intact: hand it
+    // back unlinked rather than failing the translation.
+    return ret;
+  }
+
+  llvm::Linker linker(*ret);
+  for (std::unique_ptr<llvm::Module> &libModule : mbModules.getValue()) {
+    // This bitcode linking code is substantially similar to what is used in
+    // hip-clang. It imports the library functions into the module, allowing
+    // LLVM optimization passes (which must run after linking) to optimize
+    // across the libraries and the module's code. We also only import symbols
+    // if they are referenced by the module or a previous library since there
+    // will be no other source of references to those symbols in this
+    // compilation and since we don't want to bloat the resulting code object.
+    bool err = linker.linkInModule(
+        std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
+        [](llvm::Module &m, const StringSet<> &gvs) {
+          // Preserve unnamed globals and those not in `gvs`; the globals
+          // named in `gvs` are the ones internalized.
+          llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
+            return !gv.hasName() || (gvs.count(gv.getName()) == 0);
+          });
+        });
+    // True is linker failure
+    if (err) {
+      getOperation().emitError(
+          "Unrecoverable failure during device library linking.");
+      // We have no guarantees about the state of `ret`, so bail
+      return nullptr;
+    }
+  }
+  return ret;
+}
+
+
LogicalResult
SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule,
llvm::TargetMachine &targetMachine) {
More information about the Mlir-commits
mailing list