[Mlir-commits] [mlir] [MLIR][NVVM] Add binaryCallback (PR #170853)
Guray Ozen
llvmlistbot at llvm.org
Tue Dec 9 01:08:21 PST 2025
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/170853
>From d05967081850a7caa86eaf71eb3c8b40cbbcee37 Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Tue, 9 Dec 2025 10:08:10 +0100
Subject: [PATCH] [MLIR][NVVM] Add binaryCallback
---
.../Dialect/GPU/IR/CompilationInterfaces.h | 12 ++++-
.../include/mlir/Target/LLVM/ModuleToObject.h | 6 ++-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 19 +++++--
mlir/lib/Target/LLVM/ModuleToObject.cpp | 6 ++-
mlir/lib/Target/LLVM/NVVM/Target.cpp | 51 ++++++++++++++-----
.../Target/LLVM/SerializeNVVMTarget.cpp | 34 +++++++++++++
6 files changed, 106 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..0e5111583f456 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -58,7 +58,8 @@ class TargetOptions {
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<void(StringRef)> isaCallback = {},
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -111,6 +112,9 @@ class TargetOptions {
/// for example PTX assembly.
function_ref<void(StringRef)> getISACallback() const;
+ /// Returns the callback invoked with the binary compiler's diagnostic log.
+ function_ref<void(StringRef)> getbinaryCompilerDiagnosticCallback() const;
+
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -130,7 +134,8 @@ class TargetOptions {
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<void(StringRef)> isaCallback = {},
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -167,6 +172,9 @@ class TargetOptions {
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
+ /// Callback invoked with the binary compiler's diagnostic log.
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback;
+
private:
TypeID typeID;
};
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..4c23feeccc75c 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -35,7 +35,8 @@ class ModuleToObject {
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<void(StringRef)> isaCallback = {},
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -134,6 +135,9 @@ class ModuleToObject {
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
+ /// Callback for diagnostic messages from the binary compiler.
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback;
+
private:
/// The TargetMachine created for the given Triple, if available.
/// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 61a630aa88960..7734a3e63842d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2655,12 +2655,13 @@ TargetOptions::TargetOptions(
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<void(StringRef)> isaCallback,
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
- linkedLlvmIRCallback, optimizedLlvmIRCallback,
- isaCallback) {}
+ linkedLlvmIRCallback, optimizedLlvmIRCallback, isaCallback,
+ binaryCompilerDiagnosticCallback) {}
TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
@@ -2670,7 +2671,8 @@ TargetOptions::TargetOptions(
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<void(StringRef)> isaCallback,
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
@@ -2678,7 +2680,9 @@ TargetOptions::TargetOptions(
initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
- isaCallback(isaCallback), typeID(typeID) {}
+ isaCallback(isaCallback),
+ binaryCompilerDiagnosticCallback(binaryCompilerDiagnosticCallback),
+ typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }
@@ -2715,6 +2719,11 @@ function_ref<void(StringRef)> TargetOptions::getISACallback() const {
return isaCallback;
}
+function_ref<void(StringRef)>
+TargetOptions::getbinaryCompilerDiagnosticCallback() const {
+ return binaryCompilerDiagnosticCallback;
+}
+
CompilationTarget TargetOptions::getCompilationTarget() const {
return compilationTarget;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..537643035e4cf 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -39,12 +39,14 @@ ModuleToObject::ModuleToObject(
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<void(StringRef)> isaCallback,
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
- isaCallback(isaCallback) {}
+ isaCallback(isaCallback),
+ binaryCompilerDiagnosticCallback(binaryCompilerDiagnosticCallback) {}
ModuleToObject::~ModuleToObject() = default;
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..2049dea3ea9e0 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -24,6 +24,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/ADT/ScopeExit.h"
@@ -103,7 +104,8 @@ SerializeGPUModuleBase::SerializeGPUModuleBase(
targetOptions.getInitialLlvmIRCallback(),
targetOptions.getLinkedLlvmIRCallback(),
targetOptions.getOptimizedLlvmIRCallback(),
- targetOptions.getISACallback()),
+ targetOptions.getISACallback(),
+ targetOptions.getbinaryCompilerDiagnosticCallback()),
target(target), toolkitPath(targetOptions.getToolkitPath()),
librariesToLink(targetOptions.getLibrariesToLink()) {
@@ -212,12 +214,14 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
gpu::GPUModuleOp getOperation();
/// Compiles PTX to cubin using `ptxas`.
- std::optional<SmallVector<char, 0>>
- compileToBinary(const std::string &ptxCode);
+ std::optional<SmallVector<char, 0>> compileToBinary(
+ const std::string &ptxCode,
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback);
/// Compiles PTX to cubin using the `nvptxcompiler` library.
- std::optional<SmallVector<char, 0>>
- compileToBinaryNVPTX(const std::string &ptxCode);
+ std::optional<SmallVector<char, 0>> compileToBinaryNVPTX(
+ const std::string &ptxCode,
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback);
/// Serializes the LLVM module to an object format, depending on the
/// compilation target selected in target options.
@@ -346,13 +350,13 @@ static void setOptionalCommandlineArguments(NVVMTargetAttr target,
// TODO: clean this method & have a generic tool driver or never emit binaries
// with this mechanism and let another stage take care of it.
-std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
+std::optional<SmallVector<char, 0>> NVPTXSerializer::compileToBinary(
+ const std::string &ptxCode,
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback) {
// Determine if the serializer should create a fatbinary with the PTX embeded
// or a simple CUBIN binary.
const bool createFatbin =
targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin;
-
// Find the `ptxas` & `fatbinary` tools.
std::optional<std::string> ptxasCompiler = findTool("ptxas");
if (!ptxasCompiler)
@@ -521,6 +525,15 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
/*ErrMsg=*/&message))
return emitLogError("`fatbinary`");
+ if (binaryCompilerDiagnosticCallback) {
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
+ llvm::MemoryBuffer::getFile(logFile->first);
+ if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
+ StringRef logRef = (*logBuffer)->getBuffer();
+ binaryCompilerDiagnosticCallback(logRef);
+ }
+ }
+
// Dump the output of the tools, helpful if the verbose flag was passed.
#define DEBUG_TYPE "serialize-to-binary"
LLVM_DEBUG({
@@ -569,8 +582,9 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
} \
} while (false)
-std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
+std::optional<SmallVector<char, 0>> NVPTXSerializer::compileToBinaryNVPTX(
+ const std::string &ptxCode,
+ function_ref<void(StringRef)> binaryCompilerDiagnosticCallback) {
Location loc = getOperation().getLoc();
nvPTXCompilerHandle compiler = nullptr;
nvPTXCompileResult status;
@@ -618,6 +632,18 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
RETURN_ON_NVPTXCOMPILER_ERROR(
nvPTXCompilerGetCompiledProgram(compiler, (void *)binary.data()));
+ if (binaryCompilerDiagnosticCallback) {
+ RETURN_ON_NVPTXCOMPILER_ERROR(
+ nvPTXCompilerGetInfoLogSize(compiler, &logSize));
+ if (logSize != 0) {
+ SmallVector<char> log(logSize + 1, 0);
+ RETURN_ON_NVPTXCOMPILER_ERROR(
+ nvPTXCompilerGetInfoLog(compiler, log.data()));
+ StringRef logRef(log.data(), log.size());
+ binaryCompilerDiagnosticCallback(logRef);
+ }
+ }
+
// Dump the log of the compiler, helpful if the verbose flag was passed.
#define DEBUG_TYPE "serialize-to-binary"
LLVM_DEBUG({
@@ -723,9 +749,10 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
moduleToObjectTimer.startTimer();
// Compile to binary.
#if MLIR_ENABLE_NVPTXCOMPILER
- result = compileToBinaryNVPTX(*serializedISA);
+ result =
+ compileToBinaryNVPTX(*serializedISA, binaryCompilerDiagnosticCallback);
#else
- result = compileToBinary(*serializedISA);
+ result = compileToBinary(*serializedISA, binaryCompilerDiagnosticCallback);
#endif // MLIR_ENABLE_NVPTXCOMPILER
moduleToObjectTimer.stopTimer();
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..f31b8a7c0619a 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -220,6 +220,40 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}
+// Test that the binary compiler diagnostic callback receives the ptxas log.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackForBinary)) {
+ if (!hasPtxas())
+ GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
+
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+
+ // Create an NVVM target.
+ NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+ std::string binaryResult;
+ auto binaryCompilerDiagnosticCallback =
+ [&binaryResult](llvm::StringRef binaryTarget) {
+ binaryResult = binaryTarget.str();
+ };
+ // Serialize the module.
+ auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+ ASSERT_TRUE(!!serializer);
+ gpu::TargetOptions options("", {}, "-v", "", gpu::CompilationTarget::Binary,
+ {}, {}, {}, {}, {},
+ binaryCompilerDiagnosticCallback);
+ for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+ std::optional<SmallVector<char, 0>> object =
+ serializer.serializeToObject(gpuModule, options);
+ // Check that the serializer was successful.
+ ASSERT_TRUE(object != std::nullopt);
+ ASSERT_TRUE(!object->empty());
+ ASSERT_TRUE(llvm::StringRef(binaryResult).contains("ptxas info"));
+ }
+}
+
// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
More information about the Mlir-commits
mailing list