[Mlir-commits] [mlir] [MLIR] Add callback functions for ModuleToObject (PR #116007)
Zichen Lu
llvmlistbot at llvm.org
Thu Nov 14 06:44:30 PST 2024
https://github.com/MikaOvO updated https://github.com/llvm/llvm-project/pull/116007
>From a6129d8584e57e0aa420b45cac9b1e5398b40a3c Mon Sep 17 00:00:00 2001
From: Zichen Lu <mikaovo2000 at gmail.com>
Date: Wed, 13 Nov 2024 15:35:46 +0800
Subject: [PATCH] Add callback functions for ModuleToObject
---
.../Dialect/GPU/IR/CompilationInterfaces.h | 44 +++++++++++-
.../include/mlir/Target/LLVM/ModuleToObject.h | 24 ++++++-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 41 +++++++++--
mlir/lib/Target/LLVM/ModuleToObject.cpp | 22 +++++-
mlir/lib/Target/LLVM/NVVM/Target.cpp | 3 +
.../Target/LLVM/SerializeNVVMTarget.cpp | 68 +++++++++++++++++++
6 files changed, 191 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 6d7cb5ca7a7f81..d4b16a1de8eddc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
#include "mlir/IR/Attributes.h"
+#include "llvm/IR/Module.h"
namespace llvm {
class IRBuilderBase;
@@ -52,7 +53,11 @@ class TargetOptions {
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -80,6 +85,22 @@ class TargetOptions {
/// table.
SymbolTable *getSymbolTable() const;
+ /// Returns the callback invoked with the initial LLVM IR for the device
+ /// module, before device libraries are linked. May be empty.
+ function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module
+ /// after linking the device libraries. May be empty.
+ function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen. May be empty.
+ function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with the target ISA for the device, for
+ /// example PTX assembly, by serializers that emit ISA. May be empty.
+ function_ref<void(StringRef)> getISACallback() const;
+
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -90,7 +111,11 @@ class TargetOptions {
TypeID typeID, StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
/// being serialized.
function_ref<SymbolTable *()> getSymbolTableCallback;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
TypeID typeID;
};
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index e40d7e9a43dd6b..07fc55b41ae9c5 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -29,8 +29,13 @@ class ModuleTranslation;
/// operations being transformed must be translatable into LLVM IR.
class ModuleToObject {
public:
- ModuleToObject(Operation &module, StringRef triple, StringRef chip,
- StringRef features = {}, int optLevel = 3);
+ ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip,
+ StringRef features = {}, int optLevel = 3,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -114,6 +119,21 @@ class ModuleToObject {
/// Optimization level.
int optLevel;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
/// The TargetMachine created for the given Triple, if available.
/// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 956877497d9338..d62ea72dcea2f6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
TargetOptions::TargetOptions(
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
- cmdOptions, compilationTarget, getSymbolTableCallback) {}
+ cmdOptions, compilationTarget, getSymbolTableCallback,
+ initialLlvmIRCallback, linkedLlvmIRCallback,
+ optimizedLlvmIRCallback, isaCallback) {}
TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
- getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
+ getSymbolTableCallback(getSymbolTableCallback),
+ initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback), typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }
@@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
+function_ref<void(llvm::Module &)>
+TargetOptions::getInitialLlvmIRCallback() const {
+ return initialLlvmIRCallback; // Non-owning; empty if none was supplied.
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getLinkedLlvmIRCallback() const {
+ return linkedLlvmIRCallback; // Non-owning; empty if none was supplied.
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getOptimizedLlvmIRCallback() const {
+ return optimizedLlvmIRCallback; // Non-owning; empty if none was supplied.
+}
+
+function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+ return isaCallback; // Non-owning; empty if none was supplied.
+}
+
CompilationTarget TargetOptions::getCompilationTarget() const {
return compilationTarget;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 77391341adaad2..3f5b3d5e31864b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -34,10 +34,17 @@
using namespace mlir;
using namespace mlir::LLVM;
-ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
- StringRef chip, StringRef features, int optLevel)
+ModuleToObject::ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip, StringRef features,
+ int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
- optLevel(optLevel) {}
+ optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback) {}
ModuleToObject::~ModuleToObject() = default;
@@ -215,6 +222,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
+ if (initialLlvmIRCallback)
+ initialLlvmIRCallback(*llvmModule);
+
// Link bitcode files.
handleModulePreLink(*llvmModule);
{
@@ -227,10 +237,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
+ if (linkedLlvmIRCallback)
+ linkedLlvmIRCallback(*llvmModule);
+
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
+ if (optimizedLlvmIRCallback)
+ optimizedLlvmIRCallback(*llvmModule);
+
// Return the serialized object.
return moduleToObject(*llvmModule);
}
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 69602af8563aa0..2a95f343bb2f84 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
getOperation().emitError() << "Failed translating the module to ISA.";
return std::nullopt;
}
+ if (isaCallback)
+ isaCallback(serializedISA.value());
+
#define DEBUG_TYPE "serialize-to-isa"
LLVM_DEBUG({
llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index 642aa045178095..2a7aeff149b229 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -156,3 +156,71 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) {
ASSERT_TRUE(!object->empty());
}
}
+
+// Test that the LLVM IR and ISA callbacks carried by gpu::TargetOptions are
+// each invoked while serializing a GPU module.
+TEST_F(MLIRTargetLLVMNVVM,
+       SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) {
+  if (!hasPtxas())
+    GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
+
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+
+  // Print each module directly into a std::string; unlike reading an
+  // ostringstream under a live raw_os_ostream, this cannot miss buffered text.
+  std::string initialLLVMIR;
+  auto initialCallback = [&initialLLVMIR](llvm::Module &llvmModule) {
+    llvm::raw_string_ostream ros(initialLLVMIR);
+    llvmModule.print(ros, nullptr);
+  };
+
+  std::string linkedLLVMIR;
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &llvmModule) {
+    llvm::raw_string_ostream ros(linkedLLVMIR);
+    llvmModule.print(ros, nullptr);
+  };
+
+  std::string optimizedLLVMIR;
+  auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &llvmModule) {
+    llvm::raw_string_ostream ros(optimizedLLVMIR);
+    llvmModule.print(ros, nullptr);
+  };
+
+  std::string isaResult;
+  auto isaCallback = [&isaResult](llvm::StringRef isa) {
+    isaResult = isa.str();
+  };
+
+  // TargetOptions keeps the callbacks as non-owning function_refs, so the
+  // lambdas above must be named locals that outlive every use of `options`.
+  gpu::TargetOptions options({}, {}, {}, gpu::CompilationTarget::Binary, {},
+                             initialCallback, linkedCallback, optimizedCallback,
+                             isaCallback);
+
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+    std::optional<SmallVector<char, 0>> object =
+        serializer.serializeToObject(gpuModule, options);
+
+    ASSERT_TRUE(object != std::nullopt);
+    ASSERT_TRUE(!object->empty());
+    ASSERT_TRUE(!initialLLVMIR.empty());
+    ASSERT_TRUE(!linkedLLVMIR.empty());
+    ASSERT_TRUE(!optimizedLLVMIR.empty());
+    ASSERT_TRUE(!isaResult.empty());
+
+    // Reset so that each module must trigger every callback again.
+    initialLLVMIR.clear();
+    linkedLLVMIR.clear();
+    optimizedLLVMIR.clear();
+    isaResult.clear();
+  }
+}
More information about the Mlir-commits
mailing list