[Mlir-commits] [mlir] [MLIR][NVVM] Add binaryCallback (PR #170853)

Guray Ozen llvmlistbot at llvm.org
Tue Dec 9 01:08:21 PST 2025


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/170853

>From d05967081850a7caa86eaf71eb3c8b40cbbcee37 Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Tue, 9 Dec 2025 10:08:10 +0100
Subject: [PATCH] [MLIR][NVVM] Add binaryCallback

---
 .../Dialect/GPU/IR/CompilationInterfaces.h    | 12 ++++-
 .../include/mlir/Target/LLVM/ModuleToObject.h |  6 ++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 19 +++++--
 mlir/lib/Target/LLVM/ModuleToObject.cpp       |  6 ++-
 mlir/lib/Target/LLVM/NVVM/Target.cpp          | 51 ++++++++++++++-----
 .../Target/LLVM/SerializeNVVMTarget.cpp       | 34 +++++++++++++
 6 files changed, 106 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..0e5111583f456 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -58,7 +58,8 @@ class TargetOptions {
       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<void(StringRef)> isaCallback = {},
+      function_ref<void(StringRef)> binaryCompilerDiagnosticCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -111,6 +112,9 @@ class TargetOptions {
   /// for example PTX assembly.
   function_ref<void(StringRef)> getISACallback() const;
 
+  /// Returns the callback invoked with the binary compiler's diagnostic output.
+  function_ref<void(StringRef)> getbinaryCompilerDiagnosticCallback() const;
+
   /// Returns the default compilation target: `CompilationTarget::Fatbin`.
   static CompilationTarget getDefaultCompilationTarget();
 
@@ -130,7 +134,8 @@ class TargetOptions {
       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<void(StringRef)> isaCallback = {},
+      function_ref<void(StringRef)> binaryCompilerDiagnosticCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -167,6 +172,9 @@ class TargetOptions {
   /// for example PTX assembly.
   function_ref<void(StringRef)> isaCallback;
 
+  /// Callback invoked with diagnostic messages from the binary compiler.
+  function_ref<void(StringRef)> binaryCompilerDiagnosticCallback;
+
 private:
   TypeID typeID;
 };
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..4c23feeccc75c 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -35,7 +35,8 @@ class ModuleToObject {
       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<void(StringRef)> isaCallback = {},
+      function_ref<void(StringRef)> binaryCompilerDiagnosticCallback = {});
   virtual ~ModuleToObject();
 
   /// Returns the operation being serialized.
@@ -134,6 +135,9 @@ class ModuleToObject {
   /// for example PTX assembly.
   function_ref<void(StringRef)> isaCallback;
 
+  /// Callback for diagnostic messages from the binary compiler.
+  function_ref<void(StringRef)> binaryCompilerDiagnosticCallback;
+
 private:
   /// The TargetMachine created for the given Triple, if available.
   /// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 61a630aa88960..7734a3e63842d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2655,12 +2655,13 @@ TargetOptions::TargetOptions(
     function_ref<void(llvm::Module &)> initialLlvmIRCallback,
     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<void(StringRef)> isaCallback,
+    function_ref<void(StringRef)> binaryCompilerDiagnosticCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                     cmdOptions, elfSection, compilationTarget,
                     getSymbolTableCallback, initialLlvmIRCallback,
-                    linkedLlvmIRCallback, optimizedLlvmIRCallback,
-                    isaCallback) {}
+                    linkedLlvmIRCallback, optimizedLlvmIRCallback, isaCallback,
+                    binaryCompilerDiagnosticCallback) {}
 
 TargetOptions::TargetOptions(
     TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
@@ -2670,7 +2671,8 @@ TargetOptions::TargetOptions(
     function_ref<void(llvm::Module &)> initialLlvmIRCallback,
     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<void(StringRef)> isaCallback,
+    function_ref<void(StringRef)> binaryCompilerDiagnosticCallback)
     : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
       cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
       compilationTarget(compilationTarget),
@@ -2678,7 +2680,9 @@ TargetOptions::TargetOptions(
       initialLlvmIRCallback(initialLlvmIRCallback),
       linkedLlvmIRCallback(linkedLlvmIRCallback),
       optimizedLlvmIRCallback(optimizedLlvmIRCallback),
-      isaCallback(isaCallback), typeID(typeID) {}
+      isaCallback(isaCallback),
+      binaryCompilerDiagnosticCallback(binaryCompilerDiagnosticCallback),
+      typeID(typeID) {}
 
 TypeID TargetOptions::getTypeID() const { return typeID; }
 
@@ -2715,6 +2719,11 @@ function_ref<void(StringRef)> TargetOptions::getISACallback() const {
   return isaCallback;
 }
 
+function_ref<void(StringRef)>
+TargetOptions::getbinaryCompilerDiagnosticCallback() const {
+  return binaryCompilerDiagnosticCallback;
+}
+
 CompilationTarget TargetOptions::getCompilationTarget() const {
   return compilationTarget;
 }
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..537643035e4cf 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -39,12 +39,14 @@ ModuleToObject::ModuleToObject(
     int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
     function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
     function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<void(StringRef)> isaCallback,
+    function_ref<void(StringRef)> binaryCompilerDiagnosticCallback)
     : module(module), triple(triple), chip(chip), features(features),
       optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
       linkedLlvmIRCallback(linkedLlvmIRCallback),
       optimizedLlvmIRCallback(optimizedLlvmIRCallback),
-      isaCallback(isaCallback) {}
+      isaCallback(isaCallback),
+      binaryCompilerDiagnosticCallback(binaryCompilerDiagnosticCallback) {}
 
 ModuleToObject::~ModuleToObject() = default;
 
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..2049dea3ea9e0 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/InterleavedRange.h"
 
 #include "llvm/ADT/ScopeExit.h"
@@ -103,7 +104,8 @@ SerializeGPUModuleBase::SerializeGPUModuleBase(
                      targetOptions.getInitialLlvmIRCallback(),
                      targetOptions.getLinkedLlvmIRCallback(),
                      targetOptions.getOptimizedLlvmIRCallback(),
-                     targetOptions.getISACallback()),
+                     targetOptions.getISACallback(),
+                     targetOptions.getbinaryCompilerDiagnosticCallback()),
       target(target), toolkitPath(targetOptions.getToolkitPath()),
       librariesToLink(targetOptions.getLibrariesToLink()) {
 
@@ -212,12 +214,14 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
   gpu::GPUModuleOp getOperation();
 
   /// Compiles PTX to cubin using `ptxas`.
-  std::optional<SmallVector<char, 0>>
-  compileToBinary(const std::string &ptxCode);
+  std::optional<SmallVector<char, 0>> compileToBinary(
+      const std::string &ptxCode,
+      function_ref<void(StringRef)> binaryCompilerDiagnosticCallback);
 
   /// Compiles PTX to cubin using the `nvptxcompiler` library.
-  std::optional<SmallVector<char, 0>>
-  compileToBinaryNVPTX(const std::string &ptxCode);
+  std::optional<SmallVector<char, 0>> compileToBinaryNVPTX(
+      const std::string &ptxCode,
+      function_ref<void(StringRef)> binaryCompilerDiagnosticCallback);
 
   /// Serializes the LLVM module to an object format, depending on the
   /// compilation target selected in target options.
@@ -346,13 +350,13 @@ static void setOptionalCommandlineArguments(NVVMTargetAttr target,
 
 // TODO: clean this method & have a generic tool driver or never emit binaries
 // with this mechanism and let another stage take care of it.
-std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
+std::optional<SmallVector<char, 0>> NVPTXSerializer::compileToBinary(
+    const std::string &ptxCode,
+    function_ref<void(StringRef)> binaryCompilerDiagnosticCallback) {
   // Determine if the serializer should create a fatbinary with the PTX embeded
   // or a simple CUBIN binary.
   const bool createFatbin =
       targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin;
-
   // Find the `ptxas` & `fatbinary` tools.
   std::optional<std::string> ptxasCompiler = findTool("ptxas");
   if (!ptxasCompiler)
@@ -521,6 +525,15 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
                                                 /*ErrMsg=*/&message))
     return emitLogError("`fatbinary`");
 
+  if (binaryCompilerDiagnosticCallback) {
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
+        llvm::MemoryBuffer::getFile(logFile->first);
+    if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
+      StringRef logRef = (*logBuffer)->getBuffer();
+      binaryCompilerDiagnosticCallback(logRef);
+    }
+  }
+
 // Dump the output of the tools, helpful if the verbose flag was passed.
 #define DEBUG_TYPE "serialize-to-binary"
   LLVM_DEBUG({
@@ -569,8 +582,9 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
     }                                                                          \
   } while (false)
 
-std::optional<SmallVector<char, 0>>
-NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
+std::optional<SmallVector<char, 0>> NVPTXSerializer::compileToBinaryNVPTX(
+    const std::string &ptxCode,
+    function_ref<void(StringRef)> binaryCompilerDiagnosticCallback) {
   Location loc = getOperation().getLoc();
   nvPTXCompilerHandle compiler = nullptr;
   nvPTXCompileResult status;
@@ -618,6 +632,18 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
   RETURN_ON_NVPTXCOMPILER_ERROR(
       nvPTXCompilerGetCompiledProgram(compiler, (void *)binary.data()));
 
+  if (binaryCompilerDiagnosticCallback) {
+    RETURN_ON_NVPTXCOMPILER_ERROR(
+        nvPTXCompilerGetInfoLogSize(compiler, &logSize));
+    if (logSize != 0) {
+      SmallVector<char> log(logSize + 1, 0);
+      RETURN_ON_NVPTXCOMPILER_ERROR(
+          nvPTXCompilerGetInfoLog(compiler, log.data()));
+      StringRef logRef(log.data(), log.size());
+      binaryCompilerDiagnosticCallback(logRef);
+    }
+  }
+
 // Dump the log of the compiler, helpful if the verbose flag was passed.
 #define DEBUG_TYPE "serialize-to-binary"
   LLVM_DEBUG({
@@ -723,9 +749,10 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
   moduleToObjectTimer.startTimer();
   // Compile to binary.
 #if MLIR_ENABLE_NVPTXCOMPILER
-  result = compileToBinaryNVPTX(*serializedISA);
+  result =
+      compileToBinaryNVPTX(*serializedISA, binaryCompilerDiagnosticCallback);
 #else
-  result = compileToBinary(*serializedISA);
+  result = compileToBinary(*serializedISA, binaryCompilerDiagnosticCallback);
 #endif // MLIR_ENABLE_NVPTXCOMPILER
 
   moduleToObjectTimer.stopTimer();
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..f31b8a7c0619a 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -220,6 +220,40 @@ TEST_F(MLIRTargetLLVMNVVM,
   }
 }
 
+// Test that the binary compiler diagnostic callback receives the compiler log.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackForBinary)) {
+  if (!hasPtxas())
+    GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
+
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  // Create an NVVM target.
+  NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+  std::string binaryResult;
+  auto binaryCompilerDiagnosticCallback =
+      [&binaryResult](llvm::StringRef binaryTarget) {
+        binaryResult = binaryTarget.str();
+      };
+  // Serialize the module.
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+  gpu::TargetOptions options("", {}, "-v", "", gpu::CompilationTarget::Binary,
+                             {}, {}, {}, {}, {},
+                             binaryCompilerDiagnosticCallback);
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+    std::optional<SmallVector<char, 0>> object =
+        serializer.serializeToObject(gpuModule, options);
+    // Check that the serializer was successful.
+    ASSERT_TRUE(object != std::nullopt);
+    ASSERT_TRUE(!object->empty());
+    ASSERT_TRUE(llvm::StringRef(binaryResult).contains("ptxas info"));
+  }
+}
+
 // Test linking LLVM IR from a resource attribute.
 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
   MLIRContext context(registry);



More information about the Mlir-commits mailing list