[Mlir-commits] [mlir] [mlir][gpu] Propagate errors from `ModuleToObject` callbacks (PR #170134)
Ivan Butygin
llvmlistbot at llvm.org
Tue Dec 2 14:04:48 PST 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/170134
>From 909d916eaf8c3f2677e4529d4875b04e28b9c4d9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Dec 2025 13:54:07 +0100
Subject: [PATCH 1/3] [mlir][gpu] Propagate errors from `ModuleToObject`
callbacks
---
.../Dialect/GPU/IR/CompilationInterfaces.h | 33 ++++----
.../include/mlir/Target/LLVM/ModuleToObject.h | 16 ++--
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 24 +++---
mlir/lib/Target/LLVM/ModuleToObject.cpp | 33 +++++---
mlir/lib/Target/LLVM/NVVM/Target.cpp | 8 +-
.../Target/LLVM/SerializeNVVMTarget.cpp | 45 ++++++++--
.../Target/LLVM/SerializeToLLVMBitcode.cpp | 83 ++++++++++++++++++-
7 files changed, 186 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..00f885898ffa1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -55,10 +55,10 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -97,19 +97,20 @@ class TargetOptions {
/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
- function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
- function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)>
+ getOptimizedLlvmIRCallback() const;
/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> getISACallback() const;
+ function_ref<LogicalResult(StringRef)> getISACallback() const;
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -127,10 +128,10 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -153,19 +154,19 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> isaCallback;
+ function_ref<LogicalResult(StringRef)> isaCallback;
private:
TypeID typeID;
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..0edc20cd32620 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -32,10 +32,10 @@ class ModuleToObject {
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -120,19 +120,19 @@ class ModuleToObject {
int optLevel;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> isaCallback;
+ function_ref<LogicalResult(StringRef)> isaCallback;
private:
/// The TargetMachine created for the given Triple, if available.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..a813608fdf209 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
@@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}
-function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
return isaCallback;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..d881dda69453b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -36,10 +36,11 @@ using namespace mlir::LLVM;
ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
- int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ int optLevel,
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
- if (initialLlvmIRCallback)
- initialLlvmIRCallback(*llvmModule);
+ if (initialLlvmIRCallback) {
+ if (failed(initialLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "InitialLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Link bitcode files.
handleModulePreLink(*llvmModule);
@@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
- if (linkedLlvmIRCallback)
- linkedLlvmIRCallback(*llvmModule);
+ if (linkedLlvmIRCallback) {
+ if (failed(linkedLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "LinkedLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
- if (optimizedLlvmIRCallback)
- optimizedLlvmIRCallback(*llvmModule);
+ if (optimizedLlvmIRCallback) {
+ if (failed(optimizedLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Return the serialized object.
return moduleToObject(*llvmModule);
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..cbd6a6d878813 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
return std::nullopt;
}
- if (isaCallback)
- isaCallback(serializedISA.value());
+ if (isaCallback) {
+ if (failed(isaCallback(serializedISA.value()))) {
+ getOperation().emitError() << "ISACallback failed.";
+ return std::nullopt;
+ }
+ }
#define DEBUG_TYPE "serialize-to-isa"
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..1692c4490e4d1 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM,
ASSERT_TRUE(!!serializer);
std::string initialLLVMIR;
- auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ auto initialCallback =
+ [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string optimizedLLVMIR;
- auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ auto optimizedCallback =
+ [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string isaResult;
- auto isaCallback = [&isaResult](llvm::StringRef isa) {
+ auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
isaResult = isa.str();
+ return success();
};
gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
@@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}
+// Test callback functions failure with ISA.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+
+ NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+ auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+ ASSERT_TRUE(!!serializer);
+
+ auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
+ {}, {}, {}, {}, isaCallback);
+
+ for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+ std::optional<SmallVector<char, 0>> object =
+ serializer.serializeToObject(gpuModule, options);
+
+ ASSERT_TRUE(object == std::nullopt);
+ }
+}
+
// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
@@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
// Store the bitcode as a DenseI8ArrayAttr.
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 3c880edee4ffc..b392065132787 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string initialLLVMIR;
- auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ auto initialCallback =
+ [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM,
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string optimizedLLVMIR;
- auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ auto optimizedCallback =
+ [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM,
ASSERT_TRUE(!serializedBinary->empty());
ASSERT_TRUE(!optimizedLLVMIR.empty());
}
+
+// Test callback function failure with initial LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, initialCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with linked LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, {}, linkedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with optimized LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, {}, {}, optimizedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
>From ce94f04b031c627455eb661c95faccbc1e3237fc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Dec 2025 17:24:22 +0100
Subject: [PATCH 2/3] Pass the operation to callbacks; drop generic callback-failure errors
---
.../Dialect/GPU/IR/CompilationInterfaces.h | 43 ++++++++++++-------
.../include/mlir/Target/LLVM/ModuleToObject.h | 22 ++++++----
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 31 +++++++------
mlir/lib/Target/LLVM/ModuleToObject.cpp | 32 ++++++--------
mlir/lib/Target/LLVM/NVVM/Target.cpp | 7 +--
.../Target/LLVM/SerializeNVVMTarget.cpp | 18 +++++---
.../Target/LLVM/SerializeToLLVMBitcode.cpp | 18 +++++---
7 files changed, 99 insertions(+), 72 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 00f885898ffa1..1ca25d47d2d5f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -55,10 +55,13 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<LogicalResult(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -97,20 +100,22 @@ class TargetOptions {
/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
- function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ getInitialLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
- function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ getLinkedLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<LogicalResult(llvm::Module &)>
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
getOptimizedLlvmIRCallback() const;
/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<LogicalResult(StringRef)> getISACallback() const;
+ function_ref<LogicalResult(Operation *op, StringRef)> getISACallback() const;
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -128,10 +133,13 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<LogicalResult(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -154,19 +162,22 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<LogicalResult(StringRef)> isaCallback;
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;
private:
TypeID typeID;
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 0edc20cd32620..986b210ea7765 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -32,10 +32,13 @@ class ModuleToObject {
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<LogicalResult(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -120,19 +123,22 @@ class ModuleToObject {
int optLevel;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<LogicalResult(StringRef)> isaCallback;
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;
private:
/// The TargetMachine created for the given Triple, if available.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a813608fdf209..c188517ee4155 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2652,10 +2652,13 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<LogicalResult(StringRef)> isaCallback)
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2670,13 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<LogicalResult(StringRef)> isaCallback)
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
@@ -2696,22 +2702,23 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
-function_ref<LogicalResult(llvm::Module &)>
+function_ref<LogicalResult(Operation *op, llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}
-function_ref<LogicalResult(llvm::Module &)>
+function_ref<LogicalResult(Operation *op, llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}
-function_ref<LogicalResult(llvm::Module &)>
+function_ref<LogicalResult(Operation *op, llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}
-function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
+function_ref<LogicalResult(Operation *op, StringRef)>
+TargetOptions::getISACallback() const {
return isaCallback;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index d881dda69453b..60d823ccb0d14 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -37,10 +37,13 @@ using namespace mlir::LLVM;
ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
int optLevel,
- function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
- function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<LogicalResult(StringRef)> isaCallback)
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ initialLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ linkedLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, llvm::Module &)>
+ optimizedLlvmIRCallback,
+ function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -255,12 +258,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
- if (initialLlvmIRCallback) {
- if (failed(initialLlvmIRCallback(*llvmModule))) {
- getOperation().emitError() << "InitialLLVMIRCallback failed.";
+ if (initialLlvmIRCallback)
+ if (failed(initialLlvmIRCallback(&getOperation(), *llvmModule)))
return std::nullopt;
- }
- }
// Link bitcode files.
handleModulePreLink(*llvmModule);
@@ -274,23 +274,17 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
- if (linkedLlvmIRCallback) {
- if (failed(linkedLlvmIRCallback(*llvmModule))) {
- getOperation().emitError() << "LinkedLLVMIRCallback failed.";
+ if (linkedLlvmIRCallback)
+ if (failed(linkedLlvmIRCallback(&getOperation(), *llvmModule)))
return std::nullopt;
- }
- }
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
- if (optimizedLlvmIRCallback) {
- if (failed(optimizedLlvmIRCallback(*llvmModule))) {
- getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
+ if (optimizedLlvmIRCallback)
+ if (failed(optimizedLlvmIRCallback(&getOperation(), *llvmModule)))
return std::nullopt;
- }
- }
// Return the serialized object.
return moduleToObject(*llvmModule);
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index cbd6a6d878813..5802f98af8cac 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -707,12 +707,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
return std::nullopt;
}
- if (isaCallback) {
- if (failed(isaCallback(serializedISA.value()))) {
- getOperation().emitError() << "ISACallback failed.";
+ if (isaCallback)
+ if (failed(isaCallback(getOperation(), serializedISA.value())))
return std::nullopt;
- }
- }
#define DEBUG_TYPE "serialize-to-isa"
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index 1692c4490e4d1..fdf8c3c72cfba 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -177,14 +177,16 @@ TEST_F(MLIRTargetLLVMNVVM,
std::string initialLLVMIR;
auto initialCallback =
- [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
+ [&initialLLVMIR](Operation * /*op*/,
+ llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
return success();
};
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
+ auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
+ llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
@@ -192,14 +194,16 @@ TEST_F(MLIRTargetLLVMNVVM,
std::string optimizedLLVMIR;
auto optimizedCallback =
- [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
+ [&optimizedLLVMIR](Operation * /*op*/,
+ llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
return success();
};
std::string isaResult;
- auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
+ auto isaCallback = [&isaResult](Operation * /*op*/,
+ llvm::StringRef isa) -> LogicalResult {
isaResult = isa.str();
return success();
};
@@ -239,7 +243,8 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
ASSERT_TRUE(!!serializer);
- auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
+ auto isaCallback = [](Operation * /*op*/,
+ llvm::StringRef /*isa*/) -> LogicalResult {
return failure();
};
@@ -295,7 +300,8 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
+ auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
+ llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index b392065132787..97457817c32cd 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -169,7 +169,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
std::string initialLLVMIR;
auto initialCallback =
- [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
+ [&initialLLVMIR](Operation * /*op*/,
+ llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
return success();
@@ -198,7 +199,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
+ auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
+ llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
@@ -229,7 +231,8 @@ TEST_F(MLIRTargetLLVM,
std::string optimizedLLVMIR;
auto optimizedCallback =
- [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
+ [&optimizedLLVMIR](Operation * /*op*/,
+ llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
return success();
@@ -257,7 +260,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
IntegerAttr target = builder.getI32IntegerAttr(0);
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
- auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ auto initialCallback = [](Operation * /*op*/,
+ llvm::Module & /*module*/) -> LogicalResult {
return failure();
};
@@ -281,7 +285,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
IntegerAttr target = builder.getI32IntegerAttr(0);
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
- auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ auto linkedCallback = [](Operation * /*op*/,
+ llvm::Module & /*module*/) -> LogicalResult {
return failure();
};
@@ -305,7 +310,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
IntegerAttr target = builder.getI32IntegerAttr(0);
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
- auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ auto optimizedCallback = [](Operation * /*op*/,
+ llvm::Module & /*module*/) -> LogicalResult {
return failure();
};
>From 6ac6dbce74f602e54036413f359409554b8c535e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 2 Dec 2025 22:37:37 +0100
Subject: [PATCH 3/3] Pass an InFlightDiagnostic-producing callback to `ModuleToObject` callbacks
---
.../Dialect/GPU/IR/CompilationInterfaces.h | 49 ++++++++-----------
.../include/mlir/Target/LLVM/ModuleToObject.h | 32 ++++++------
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 29 +++--------
mlir/lib/Target/LLVM/ModuleToObject.cpp | 26 +++++-----
mlir/lib/Target/LLVM/NVVM/Target.cpp | 6 ++-
.../Target/LLVM/SerializeNVVMTarget.cpp | 33 ++++++++-----
.../Target/LLVM/SerializeToLLVMBitcode.cpp | 36 ++++++++------
7 files changed, 103 insertions(+), 108 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 1ca25d47d2d5f..e5eb043dc36e1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -46,6 +46,11 @@ class OffloadingTranslationAttrTrait
/// ensure type safeness. Targets are free to ignore these options.
class TargetOptions {
public:
+ using DiagnosticCallback = function_ref<InFlightDiagnostic()>;
+ using LLVMIRCallback =
+ function_ref<LogicalResult(llvm::Module &, DiagnosticCallback)>;
+ using ISACallback =
+ function_ref<LogicalResult(StringRef, DiagnosticCallback)>;
/// Constructor initializing the toolkit path, the list of files to link to,
/// extra command line options, the compilation target and a callback for
/// obtaining the parent symbol table. The default compilation target is
@@ -55,13 +60,10 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
+ LLVMIRCallback initialLlvmIRCallback = {},
+ LLVMIRCallback linkedLlvmIRCallback = {},
+ LLVMIRCallback optimizedLlvmIRCallback = {},
+ ISACallback isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -100,22 +102,19 @@ class TargetOptions {
/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- getInitialLlvmIRCallback() const;
+ LLVMIRCallback getInitialLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- getLinkedLlvmIRCallback() const;
+ LLVMIRCallback getLinkedLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- getOptimizedLlvmIRCallback() const;
+ LLVMIRCallback getOptimizedLlvmIRCallback() const;
/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<LogicalResult(Operation *op, StringRef)> getISACallback() const;
+ ISACallback getISACallback() const;
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -133,13 +132,10 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
+ LLVMIRCallback initialLlvmIRCallback = {},
+ LLVMIRCallback linkedLlvmIRCallback = {},
+ LLVMIRCallback optimizedLlvmIRCallback = {},
+ ISACallback isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -162,22 +158,19 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback;
+ LLVMIRCallback initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback;
+ LLVMIRCallback linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback;
+ LLVMIRCallback optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;
+ ISACallback isaCallback;
private:
TypeID typeID;
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 986b210ea7765..eb5d4f9906cb9 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -29,16 +29,17 @@ class ModuleTranslation;
/// operations being transformed must be translatable into LLVM IR.
class ModuleToObject {
public:
- ModuleToObject(
- Operation &module, StringRef triple, StringRef chip,
- StringRef features = {}, int optLevel = 3,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback = {},
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
+ using DiagnosticCallback = function_ref<InFlightDiagnostic()>;
+ using LLVMIRCallback =
+ function_ref<LogicalResult(llvm::Module &, DiagnosticCallback)>;
+ using ISACallback =
+ function_ref<LogicalResult(StringRef, DiagnosticCallback)>;
+ ModuleToObject(Operation &module, StringRef triple, StringRef chip,
+ StringRef features = {}, int optLevel = 3,
+ LLVMIRCallback initialLlvmIRCallback = {},
+ LLVMIRCallback linkedLlvmIRCallback = {},
+ LLVMIRCallback optimizedLlvmIRCallback = {},
+ ISACallback isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -123,22 +124,19 @@ class ModuleToObject {
int optLevel;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback;
+ LLVMIRCallback initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback;
+ LLVMIRCallback linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback;
+ LLVMIRCallback optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;
+ ISACallback isaCallback;
private:
/// The TargetMachine created for the given Triple, if available.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index c188517ee4155..240822d1530ed 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2652,13 +2652,8 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
+ LLVMIRCallback initialLlvmIRCallback, LLVMIRCallback linkedLlvmIRCallback,
+ LLVMIRCallback optimizedLlvmIRCallback, ISACallback isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
@@ -2670,13 +2665,8 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
+ LLVMIRCallback initialLlvmIRCallback, LLVMIRCallback linkedLlvmIRCallback,
+ LLVMIRCallback optimizedLlvmIRCallback, ISACallback isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
@@ -2702,23 +2692,20 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
-function_ref<LogicalResult(Operation *op, llvm::Module &)>
-TargetOptions::getInitialLlvmIRCallback() const {
+TargetOptions::LLVMIRCallback TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}
-function_ref<LogicalResult(Operation *op, llvm::Module &)>
-TargetOptions::getLinkedLlvmIRCallback() const {
+TargetOptions::LLVMIRCallback TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}
-function_ref<LogicalResult(Operation *op, llvm::Module &)>
+TargetOptions::LLVMIRCallback
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}
-function_ref<LogicalResult(Operation *op, StringRef)>
-TargetOptions::getISACallback() const {
+TargetOptions::ISACallback TargetOptions::getISACallback() const {
return isaCallback;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 60d823ccb0d14..6e50c6c735662 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -34,16 +34,12 @@
using namespace mlir;
using namespace mlir::LLVM;
-ModuleToObject::ModuleToObject(
- Operation &module, StringRef triple, StringRef chip, StringRef features,
- int optLevel,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- initialLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- linkedLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, llvm::Module &)>
- optimizedLlvmIRCallback,
- function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
+ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
+ StringRef chip, StringRef features, int optLevel,
+ LLVMIRCallback initialLlvmIRCallback,
+ LLVMIRCallback linkedLlvmIRCallback,
+ LLVMIRCallback optimizedLlvmIRCallback,
+ ISACallback isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -258,8 +254,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
+ auto diagnosticCallback = [&]() -> InFlightDiagnostic {
+ return getOperation().emitError();
+ };
+
if (initialLlvmIRCallback)
- if (failed(initialLlvmIRCallback(&getOperation(), *llvmModule)))
+ if (failed(initialLlvmIRCallback(*llvmModule, diagnosticCallback)))
return std::nullopt;
// Link bitcode files.
@@ -275,7 +275,7 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
if (linkedLlvmIRCallback)
- if (failed(linkedLlvmIRCallback(&getOperation(), *llvmModule)))
+ if (failed(linkedLlvmIRCallback(*llvmModule, diagnosticCallback)))
return std::nullopt;
// Optimize the module.
@@ -283,7 +283,7 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
return std::nullopt;
if (optimizedLlvmIRCallback)
- if (failed(optimizedLlvmIRCallback(&getOperation(), *llvmModule)))
+ if (failed(optimizedLlvmIRCallback(*llvmModule, diagnosticCallback)))
return std::nullopt;
// Return the serialized object.
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 5802f98af8cac..7d52957cdf6ac 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
return std::nullopt;
}
+ auto diagnosticCallback = [&]() -> InFlightDiagnostic {
+ return getOperation().emitError();
+ };
+
if (isaCallback)
- if (failed(isaCallback(getOperation(), serializedISA.value())))
+ if (failed(isaCallback(serializedISA.value(), diagnosticCallback)))
return std::nullopt;
#define DEBUG_TYPE "serialize-to-isa"
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index fdf8c3c72cfba..e9987f0bcf13c 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -177,16 +177,18 @@ TEST_F(MLIRTargetLLVMNVVM,
std::string initialLLVMIR;
auto initialCallback =
- [&initialLLVMIR](Operation * /*op*/,
- llvm::Module &module) -> LogicalResult {
+ [&initialLLVMIR](
+ llvm::Module &module,
+ gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
return success();
};
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
- llvm::Module &module) -> LogicalResult {
+ auto linkedCallback =
+ [&linkedLLVMIR](llvm::Module &module,
+ gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
@@ -194,16 +196,18 @@ TEST_F(MLIRTargetLLVMNVVM,
std::string optimizedLLVMIR;
auto optimizedCallback =
- [&optimizedLLVMIR](Operation * /*op*/,
- llvm::Module &module) -> LogicalResult {
+ [&optimizedLLVMIR](
+ llvm::Module &module,
+ gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
return success();
};
std::string isaResult;
- auto isaCallback = [&isaResult](Operation * /*op*/,
- llvm::StringRef isa) -> LogicalResult {
+ auto isaCallback =
+ [&isaResult](llvm::StringRef isa,
+ gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
isaResult = isa.str();
return success();
};
@@ -243,9 +247,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
ASSERT_TRUE(!!serializer);
- auto isaCallback = [](Operation * /*op*/,
- llvm::StringRef /*isa*/) -> LogicalResult {
- return failure();
+ auto isaCallback =
+ [](llvm::StringRef /*isa*/,
+ gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult {
+ return diag() << "test";
};
gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
@@ -300,8 +305,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
- llvm::Module &module) -> LogicalResult {
+ auto linkedCallback =
+ [&linkedLLVMIR](
+ llvm::Module &module,
+ gpu::TargetOptions::DiagnosticCallback /*diag*/) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 97457817c32cd..4726bf8169515 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -169,8 +169,9 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
std::string initialLLVMIR;
auto initialCallback =
- [&initialLLVMIR](Operation * /*op*/,
- llvm::Module &module) -> LogicalResult {
+ [&initialLLVMIR](
+ llvm::Module &module,
+ gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
return success();
@@ -199,8 +200,9 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
- llvm::Module &module) -> LogicalResult {
+ auto linkedCallback =
+ [&linkedLLVMIR](llvm::Module &module,
+ gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
@@ -231,8 +233,9 @@ TEST_F(MLIRTargetLLVM,
std::string optimizedLLVMIR;
auto optimizedCallback =
- [&optimizedLLVMIR](Operation * /*op*/,
- llvm::Module &module) -> LogicalResult {
+ [&optimizedLLVMIR](
+ llvm::Module &module,
+ gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
return success();
@@ -260,9 +263,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
IntegerAttr target = builder.getI32IntegerAttr(0);
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
- auto initialCallback = [](Operation * /*op*/,
- llvm::Module & /*module*/) -> LogicalResult {
- return failure();
+ auto initialCallback =
+ [](llvm::Module & /*module*/,
+ gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult {
+ return diag() << "test";
};
gpu::TargetOptions opts(
@@ -285,9 +289,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
IntegerAttr target = builder.getI32IntegerAttr(0);
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
- auto linkedCallback = [](Operation * /*op*/,
- llvm::Module & /*module*/) -> LogicalResult {
- return failure();
+ auto linkedCallback =
+ [](llvm::Module & /*module*/,
+ gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult {
+ return diag() << "test";
};
gpu::TargetOptions opts(
@@ -310,9 +315,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
IntegerAttr target = builder.getI32IntegerAttr(0);
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
- auto optimizedCallback = [](Operation * /*op*/,
- llvm::Module & /*module*/) -> LogicalResult {
- return failure();
+ auto optimizedCallback =
+ [](llvm::Module & /*module*/,
+ gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult {
+ return diag() << "test";
};
gpu::TargetOptions opts(
More information about the Mlir-commits
mailing list