[Mlir-commits] [mlir] [MLIR][NVVM] Add binaryCallback (PR #170853)
Guray Ozen
llvmlistbot at llvm.org
Mon Dec 8 05:47:19 PST 2025
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/170853
>From cf1731aba1416ae31a025b8a4617601cad8d1b2b Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Mon, 8 Dec 2025 14:47:08 +0100
Subject: [PATCH] [MLIR][NVVM] Add binaryCallback
---
.../Dialect/GPU/IR/CompilationInterfaces.h | 12 ++++-
.../include/mlir/Target/LLVM/ModuleToObject.h | 6 ++-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 17 ++++--
mlir/lib/Target/LLVM/ModuleToObject.cpp | 5 +-
mlir/lib/Target/LLVM/NVVM/Target.cpp | 52 ++++++++++++++-----
.../Target/LLVM/SerializeNVVMTarget.cpp | 32 ++++++++++++
6 files changed, 100 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..45eb19ac88d6c 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -58,7 +58,8 @@ class TargetOptions {
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<void(StringRef)> isaCallback = {},
+ function_ref<void(StringRef)> binaryCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -111,6 +112,9 @@ class TargetOptions {
/// for example PTX assembly.
function_ref<void(StringRef)> getISACallback() const;
+ /// Returns the callback invoked with the log of the device binary compiler.
+ function_ref<void(StringRef)> getBinaryCallback() const;
+
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -130,7 +134,8 @@ class TargetOptions {
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<void(StringRef)> isaCallback = {},
+ function_ref<void(StringRef)> binaryCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -167,6 +172,9 @@ class TargetOptions {
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
+ /// Callback invoked with the log output of the device binary compiler.
+ function_ref<void(StringRef)> binaryCallback;
+
private:
TypeID typeID;
};
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..99a698f5efb04 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -35,7 +35,8 @@ class ModuleToObject {
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<void(StringRef)> isaCallback = {},
+ function_ref<void(StringRef)> binaryCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -134,6 +135,9 @@ class ModuleToObject {
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
+ /// Callback invoked with the log output of the device binary compiler.
+ function_ref<void(StringRef)> binaryCallback;
+
private:
/// The TargetMachine created for the given Triple, if available.
/// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 61a630aa88960..1efa67008b091 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2655,12 +2655,13 @@ TargetOptions::TargetOptions(
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<void(StringRef)> isaCallback,
+ function_ref<void(StringRef)> binaryCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
- linkedLlvmIRCallback, optimizedLlvmIRCallback,
- isaCallback) {}
+ linkedLlvmIRCallback, optimizedLlvmIRCallback, isaCallback,
+ binaryCallback) {}
TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
@@ -2670,7 +2671,8 @@ TargetOptions::TargetOptions(
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<void(StringRef)> isaCallback,
+ function_ref<void(StringRef)> binaryCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
@@ -2678,7 +2680,8 @@ TargetOptions::TargetOptions(
initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
- isaCallback(isaCallback), typeID(typeID) {}
+ isaCallback(isaCallback), binaryCallback(binaryCallback), typeID(typeID) {
+}
TypeID TargetOptions::getTypeID() const { return typeID; }
@@ -2715,6 +2718,10 @@ function_ref<void(StringRef)> TargetOptions::getISACallback() const {
return isaCallback;
}
+function_ref<void(StringRef)> TargetOptions::getBinaryCallback() const {
+ return binaryCallback;
+}
+
CompilationTarget TargetOptions::getCompilationTarget() const {
return compilationTarget;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..0ef1ed9484337 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -39,12 +39,13 @@ ModuleToObject::ModuleToObject(
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<void(StringRef)> isaCallback,
+ function_ref<void(StringRef)> binaryCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
- isaCallback(isaCallback) {}
+ isaCallback(isaCallback), binaryCallback(binaryCallback) {}
ModuleToObject::~ModuleToObject() = default;
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..dcd0a8da950cf 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -24,6 +24,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/ADT/ScopeExit.h"
@@ -98,12 +99,12 @@ StringRef mlir::NVVM::getCUDAToolkitPath() {
SerializeGPUModuleBase::SerializeGPUModuleBase(
Operation &module, NVVMTargetAttr target,
const gpu::TargetOptions &targetOptions)
- : ModuleToObject(module, target.getTriple(), target.getChip(),
- target.getFeatures(), target.getO(),
- targetOptions.getInitialLlvmIRCallback(),
- targetOptions.getLinkedLlvmIRCallback(),
- targetOptions.getOptimizedLlvmIRCallback(),
- targetOptions.getISACallback()),
+ : ModuleToObject(
+ module, target.getTriple(), target.getChip(), target.getFeatures(),
+ target.getO(), targetOptions.getInitialLlvmIRCallback(),
+ targetOptions.getLinkedLlvmIRCallback(),
+ targetOptions.getOptimizedLlvmIRCallback(),
+ targetOptions.getISACallback(), targetOptions.getBinaryCallback()),
target(target), toolkitPath(targetOptions.getToolkitPath()),
librariesToLink(targetOptions.getLibrariesToLink()) {
@@ -213,11 +214,13 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
/// Compiles PTX to cubin using `ptxas`.
std::optional<SmallVector<char, 0>>
- compileToBinary(const std::string &ptxCode);
+ compileToBinary(const std::string &ptxCode,
+ function_ref<void(StringRef)> binaryCallback);
/// Compiles PTX to cubin using the `nvptxcompiler` library.
std::optional<SmallVector<char, 0>>
- compileToBinaryNVPTX(const std::string &ptxCode);
+ compileToBinaryNVPTX(const std::string &ptxCode,
+ function_ref<void(StringRef)> binaryCallback);
/// Serializes the LLVM module to an object format, depending on the
/// compilation target selected in target options.
@@ -347,12 +350,12 @@ static void setOptionalCommandlineArguments(NVVMTargetAttr target,
// TODO: clean this method & have a generic tool driver or never emit binaries
// with this mechanism and let another stage take care of it.
std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
+NVPTXSerializer::compileToBinary(const std::string &ptxCode,
+ function_ref<void(StringRef)> binaryCallback) {
// Determine if the serializer should create a fatbinary with the PTX embeded
// or a simple CUBIN binary.
const bool createFatbin =
targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin;
-
// Find the `ptxas` & `fatbinary` tools.
std::optional<std::string> ptxasCompiler = findTool("ptxas");
if (!ptxasCompiler)
@@ -521,6 +524,15 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
/*ErrMsg=*/&message))
return emitLogError("`fatbinary`");
+ if (binaryCallback) {
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
+ llvm::MemoryBuffer::getFile(logFile->first);
+ if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
+ StringRef logRef = (*logBuffer)->getBuffer();
+ binaryCallback(logRef);
+ }
+ }
+
// Dump the output of the tools, helpful if the verbose flag was passed.
#define DEBUG_TYPE "serialize-to-binary"
LLVM_DEBUG({
@@ -569,8 +581,8 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
} \
} while (false)
-std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
+std::optional<SmallVector<char, 0>> NVPTXSerializer::compileToBinaryNVPTX(
+ const std::string &ptxCode, function_ref<void(StringRef)> binaryCallback) {
Location loc = getOperation().getLoc();
nvPTXCompilerHandle compiler = nullptr;
nvPTXCompileResult status;
@@ -618,6 +630,18 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
RETURN_ON_NVPTXCOMPILER_ERROR(
nvPTXCompilerGetCompiledProgram(compiler, (void *)binary.data()));
+ if (binaryCallback) {
+ RETURN_ON_NVPTXCOMPILER_ERROR(
+ nvPTXCompilerGetInfoLogSize(compiler, &logSize));
+ if (logSize != 0) {
+ SmallVector<char> log(logSize + 1, 0);
+ RETURN_ON_NVPTXCOMPILER_ERROR(
+ nvPTXCompilerGetInfoLog(compiler, log.data()));
+ // `log` has one extra byte for the NUL terminator; do not forward it.
+ binaryCallback(StringRef(log.data(), logSize));
+ }
+ }
+
// Dump the log of the compiler, helpful if the verbose flag was passed.
#define DEBUG_TYPE "serialize-to-binary"
LLVM_DEBUG({
@@ -723,9 +747,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
moduleToObjectTimer.startTimer();
// Compile to binary.
#if MLIR_ENABLE_NVPTXCOMPILER
- result = compileToBinaryNVPTX(*serializedISA);
+ result = compileToBinaryNVPTX(*serializedISA, binaryCallback);
#else
- result = compileToBinary(*serializedISA);
+ result = compileToBinary(*serializedISA, binaryCallback);
#endif // MLIR_ENABLE_NVPTXCOMPILER
moduleToObjectTimer.stopTimer();
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..fca083c555202 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -220,6 +220,38 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}
+// Test that the binary callback receives the `ptxas` log during serialization.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackForBinary)) {
+ if (!hasPtxas())
+ GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
+
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+
+ // Create an NVVM target.
+ NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+ std::string binaryResult;
+ auto binaryCallback = [&binaryResult](llvm::StringRef binaryTarget) {
+ binaryResult = binaryTarget.str();
+ };
+ // Serialize the module.
+ auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+ ASSERT_TRUE(!!serializer);
+ gpu::TargetOptions options("", {}, "-v", "", gpu::CompilationTarget::Binary,
+ {}, {}, {}, {}, {}, binaryCallback);
+ for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+ std::optional<SmallVector<char, 0>> object =
+ serializer.serializeToObject(gpuModule, options);
+ // Check that the serializer was successful.
+ ASSERT_TRUE(object != std::nullopt);
+ ASSERT_TRUE(!object->empty());
+ ASSERT_TRUE(llvm::StringRef(binaryResult).contains("ptxas info"));
+ }
+}
+
// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
More information about the Mlir-commits
mailing list