[Mlir-commits] [mlir] [MLIR][NVVM] Add binaryCallback (PR #170853)

Guray Ozen llvmlistbot at llvm.org
Mon Dec 8 05:47:19 PST 2025


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/170853

>From cf1731aba1416ae31a025b8a4617601cad8d1b2b Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Mon, 8 Dec 2025 14:47:08 +0100
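Subject: [PATCH] [MLIR][NVVM] Add binaryCallback

Add a `binaryCallback` option to `gpu::TargetOptions` and `ModuleToObject`.
When set, the NVVM serializer invokes it with the log produced while
compiling PTX to a binary: the captured `ptxas`/`fatbinary` output, or the
`nvptxcompiler` info log. With `-v`, this log contains the `ptxas info`
lines checked by the new unit test.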
---
 .../Dialect/GPU/IR/CompilationInterfaces.h    | 12 ++++-
 .../include/mlir/Target/LLVM/ModuleToObject.h |  6 ++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 17 ++++--
 mlir/lib/Target/LLVM/ModuleToObject.cpp       |  5 +-
 mlir/lib/Target/LLVM/NVVM/Target.cpp          | 52 ++++++++++++++-----
 .../Target/LLVM/SerializeNVVMTarget.cpp       | 32 ++++++++++++
 6 files changed, 100 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..45eb19ac88d6c 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -58,7 +58,8 @@ class TargetOptions {
       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<void(StringRef)> isaCallback = {},
+      function_ref<void(StringRef)> binaryCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -111,6 +112,9 @@ class TargetOptions {
   /// for example PTX assembly.
   function_ref<void(StringRef)> getISACallback() const;
 
+  /// Returns the callback invoked with the log of compiling ISA to a binary.
+  function_ref<void(StringRef)> getBinaryCallback() const;
+
   /// Returns the default compilation target: `CompilationTarget::Fatbin`.
   static CompilationTarget getDefaultCompilationTarget();
 
@@ -130,7 +134,8 @@ class TargetOptions {
       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<void(StringRef)> isaCallback = {},
+      function_ref<void(StringRef)> binaryCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -167,6 +172,9 @@ class TargetOptions {
   /// for example PTX assembly.
   function_ref<void(StringRef)> isaCallback;
 
+  /// Callback invoked with the log of compiling ISA to a binary.
+  function_ref<void(StringRef)> binaryCallback;
+
 private:
   TypeID typeID;
 };
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..99a698f5efb04 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -35,7 +35,8 @@ class ModuleToObject {
       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<void(StringRef)> isaCallback = {},
+      function_ref<void(StringRef)> binaryCallback = {});
   virtual ~ModuleToObject();
 
   /// Returns the operation being serialized.
@@ -134,6 +135,9 @@ class ModuleToObject {
   /// for example PTX assembly.
   function_ref<void(StringRef)> isaCallback;
 
+  /// Callback invoked with the log of compiling ISA to a binary.
+  function_ref<void(StringRef)> binaryCallback;
+
 private:
   /// The TargetMachine created for the given Triple, if available.
   /// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 61a630aa88960..1efa67008b091 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2655,12 +2655,13 @@ TargetOptions::TargetOptions(
     function_ref<void(llvm::Module &)> initialLlvmIRCallback,
     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<void(StringRef)> isaCallback,
+    function_ref<void(StringRef)> binaryCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                     cmdOptions, elfSection, compilationTarget,
                     getSymbolTableCallback, initialLlvmIRCallback,
-                    linkedLlvmIRCallback, optimizedLlvmIRCallback,
-                    isaCallback) {}
+                    linkedLlvmIRCallback, optimizedLlvmIRCallback, isaCallback,
+                    binaryCallback) {}
 
 TargetOptions::TargetOptions(
     TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
@@ -2670,7 +2671,8 @@ TargetOptions::TargetOptions(
     function_ref<void(llvm::Module &)> initialLlvmIRCallback,
     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<void(StringRef)> isaCallback,
+    function_ref<void(StringRef)> binaryCallback)
     : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
       cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
       compilationTarget(compilationTarget),
@@ -2678,7 +2680,8 @@ TargetOptions::TargetOptions(
       initialLlvmIRCallback(initialLlvmIRCallback),
       linkedLlvmIRCallback(linkedLlvmIRCallback),
       optimizedLlvmIRCallback(optimizedLlvmIRCallback),
-      isaCallback(isaCallback), typeID(typeID) {}
+      isaCallback(isaCallback), binaryCallback(binaryCallback), typeID(typeID) {
+}
 
 TypeID TargetOptions::getTypeID() const { return typeID; }
 
@@ -2715,6 +2718,10 @@ function_ref<void(StringRef)> TargetOptions::getISACallback() const {
   return isaCallback;
 }
 
+function_ref<void(StringRef)> TargetOptions::getBinaryCallback() const {
+  return binaryCallback;
+}
+
 CompilationTarget TargetOptions::getCompilationTarget() const {
   return compilationTarget;
 }
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..0ef1ed9484337 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -39,12 +39,13 @@ ModuleToObject::ModuleToObject(
     int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<void(StringRef)> isaCallback,
+    function_ref<void(StringRef)> binaryCallback)
     : module(module), triple(triple), chip(chip), features(features),
       optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
       linkedLlvmIRCallback(linkedLlvmIRCallback),
       optimizedLlvmIRCallback(optimizedLlvmIRCallback),
-      isaCallback(isaCallback) {}
+      isaCallback(isaCallback), binaryCallback(binaryCallback) {}
 
 ModuleToObject::~ModuleToObject() = default;
 
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..dcd0a8da950cf 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/InterleavedRange.h"
 
 #include "llvm/ADT/ScopeExit.h"
@@ -98,12 +99,12 @@ StringRef mlir::NVVM::getCUDAToolkitPath() {
 SerializeGPUModuleBase::SerializeGPUModuleBase(
     Operation &module, NVVMTargetAttr target,
     const gpu::TargetOptions &targetOptions)
-    : ModuleToObject(module, target.getTriple(), target.getChip(),
-                     target.getFeatures(), target.getO(),
-                     targetOptions.getInitialLlvmIRCallback(),
-                     targetOptions.getLinkedLlvmIRCallback(),
-                     targetOptions.getOptimizedLlvmIRCallback(),
-                     targetOptions.getISACallback()),
+    : ModuleToObject(
+          module, target.getTriple(), target.getChip(), target.getFeatures(),
+          target.getO(), targetOptions.getInitialLlvmIRCallback(),
+          targetOptions.getLinkedLlvmIRCallback(),
+          targetOptions.getOptimizedLlvmIRCallback(),
+          targetOptions.getISACallback(), targetOptions.getBinaryCallback()),
       target(target), toolkitPath(targetOptions.getToolkitPath()),
       librariesToLink(targetOptions.getLibrariesToLink()) {
 
@@ -213,11 +214,13 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
 
   /// Compiles PTX to cubin using `ptxas`.
   std::optional<SmallVector<char, 0>>
-  compileToBinary(const std::string &ptxCode);
+  compileToBinary(const std::string &ptxCode,
+                  function_ref<void(StringRef)> binaryCallback);
 
   /// Compiles PTX to cubin using the `nvptxcompiler` library.
   std::optional<SmallVector<char, 0>>
-  compileToBinaryNVPTX(const std::string &ptxCode);
+  compileToBinaryNVPTX(const std::string &ptxCode,
+                       function_ref<void(StringRef)> binaryCallback);
 
   /// Serializes the LLVM module to an object format, depending on the
   /// compilation target selected in target options.
@@ -347,12 +350,12 @@ static void setOptionalCommandlineArguments(NVVMTargetAttr target,
 // TODO: clean this method & have a generic tool driver or never emit binaries
 // with this mechanism and let another stage take care of it.
 std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
+NVPTXSerializer::compileToBinary(const std::string &ptxCode,
+                                 function_ref<void(StringRef)> binaryCallback) {
   // Determine if the serializer should create a fatbinary with the PTX embeded
   // or a simple CUBIN binary.
   const bool createFatbin =
       targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin;
-
   // Find the `ptxas` & `fatbinary` tools.
   std::optional<std::string> ptxasCompiler = findTool("ptxas");
   if (!ptxasCompiler)
@@ -521,6 +524,15 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
                                                  /*ErrMsg=*/&message))
     return emitLogError("`fatbinary`");
 
+  if (binaryCallback) {
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
+        llvm::MemoryBuffer::getFile(logFile->first);
+    if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
+      StringRef logRef = (*logBuffer)->getBuffer();
+      binaryCallback(logRef);
+    }
+  }
+
 // Dump the output of the tools, helpful if the verbose flag was passed.
 #define DEBUG_TYPE "serialize-to-binary"
   LLVM_DEBUG({
@@ -569,8 +581,8 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
     }                                                                          \
   } while (false)
 
-std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
+std::optional<SmallVector<char, 0>> NVPTXSerializer::compileToBinaryNVPTX(
+    const std::string &ptxCode, function_ref<void(StringRef)> binaryCallback) {
   Location loc = getOperation().getLoc();
   nvPTXCompilerHandle compiler = nullptr;
   nvPTXCompileResult status;
@@ -618,6 +630,18 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
   RETURN_ON_NVPTXCOMPILER_ERROR(
       nvPTXCompilerGetCompiledProgram(compiler, (void *)binary.data()));
 
+  if (binaryCallback) {
+    RETURN_ON_NVPTXCOMPILER_ERROR(
+        nvPTXCompilerGetInfoLogSize(compiler, &logSize));
+    if (logSize != 0) {
+      SmallVector<char> log(logSize + 1, 0);
+      RETURN_ON_NVPTXCOMPILER_ERROR(
+          nvPTXCompilerGetInfoLog(compiler, log.data()));
+      StringRef logRef(log.data()); // Stop at the NUL terminator.
+      binaryCallback(logRef);
+    }
+  }
+
 // Dump the log of the compiler, helpful if the verbose flag was passed.
 #define DEBUG_TYPE "serialize-to-binary"
   LLVM_DEBUG({
@@ -723,9 +747,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
   moduleToObjectTimer.startTimer();
   // Compile to binary.
 #if MLIR_ENABLE_NVPTXCOMPILER
-  result = compileToBinaryNVPTX(*serializedISA);
+  result = compileToBinaryNVPTX(*serializedISA, binaryCallback);
 #else
-  result = compileToBinary(*serializedISA);
+  result = compileToBinary(*serializedISA, binaryCallback);
 #endif // MLIR_ENABLE_NVPTXCOMPILER
 
   moduleToObjectTimer.stopTimer();
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..fca083c555202 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -220,6 +220,38 @@ TEST_F(MLIRTargetLLVMNVVM,
   }
 }
 
+// Test that the binary callback receives the `ptxas` info log.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackForBinary)) {
+  if (!hasPtxas())
+    GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
+
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  // Create an NVVM target.
+  NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+  std::string infoLog;
+  auto binaryCallback = [&infoLog](llvm::StringRef compilerLog) {
+    infoLog = compilerLog.str();
+  };
+  // Serialize the module.
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+  gpu::TargetOptions options("", {}, "-v", "", gpu::CompilationTarget::Binary,
+                             {}, {}, {}, {}, {}, binaryCallback);
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+    std::optional<SmallVector<char, 0>> object =
+        serializer.serializeToObject(gpuModule, options);
+    // Check that the serializer was successful.
+    ASSERT_TRUE(object != std::nullopt);
+    ASSERT_TRUE(!object->empty());
+    ASSERT_TRUE(llvm::StringRef(infoLog).contains("ptxas info"));
+  }
+}
+
 // Test linking LLVM IR from a resource attribute.
 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
   MLIRContext context(registry);



More information about the Mlir-commits mailing list