[Mlir-commits] [mlir] [mlir] Add callback functions for ModuleToObject (PR #116916)

Zichen Lu llvmlistbot at llvm.org
Tue Nov 19 21:22:03 PST 2024


https://github.com/MikaOvO created https://github.com/llvm/llvm-project/pull/116916

Here is the [merged PR](https://github.com/llvm/llvm-project/pull/116007), which caused a failure and [was reverted](https://github.com/llvm/llvm-project/pull/116811).

Thanks to @joker-eph for the help, I fixed it (the original patch missed passing the callbacks when constructing `ModuleToObject` in `mlir/lib/Target/LLVM/NVVM/Target.cpp`) and split out the unit tests that don't need `ptxas` from the original test so that they can run in more environments.


>From 22007649dc7923adbc3186998d1ebdd3a8d25570 Mon Sep 17 00:00:00 2001
From: Zichen Lu <mikaovo2000 at gmail.com>
Date: Wed, 13 Nov 2024 15:35:46 +0800
Subject: [PATCH] Add callback functions for ModuleToObject

---
 .../Dialect/GPU/IR/CompilationInterfaces.h    | 44 ++++++++-
 .../include/mlir/Target/LLVM/ModuleToObject.h | 24 ++++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 41 ++++++++-
 mlir/lib/Target/LLVM/ModuleToObject.cpp       | 22 ++++-
 mlir/lib/Target/LLVM/NVVM/Target.cpp          |  9 +-
 .../Target/LLVM/SerializeNVVMTarget.cpp       | 62 +++++++++++++
 .../Target/LLVM/SerializeToLLVMBitcode.cpp    | 92 ++++++++++++++++++-
 7 files changed, 281 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 6d7cb5ca7a7f81..d4b16a1de8eddc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
 
 #include "mlir/IR/Attributes.h"
+#include "llvm/IR/Module.h"
 
 namespace llvm {
 class IRBuilderBase;
@@ -52,7 +53,11 @@ class TargetOptions {
       StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
       StringRef cmdOptions = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
-      function_ref<SymbolTable *()> getSymbolTableCallback = {});
+      function_ref<SymbolTable *()> getSymbolTableCallback = {},
+      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<void(StringRef)> isaCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -80,6 +85,22 @@ class TargetOptions {
   /// table.
   SymbolTable *getSymbolTable() const;
 
+  /// Returns the callback invoked with the initial LLVM IR for the device
+  /// module.
+  function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+
+  /// Returns the callback invoked with LLVM IR for the device module
+  /// after linking the device libraries.
+  function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+
+  /// Returns the callback invoked with LLVM IR for the device module after
+  /// LLVM optimizations but before codegen.
+  function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+
+  /// Returns the callback invoked with the target ISA for the device,
+  /// for example PTX assembly.
+  function_ref<void(StringRef)> getISACallback() const;
+
   /// Returns the default compilation target: `CompilationTarget::Fatbin`.
   static CompilationTarget getDefaultCompilationTarget();
 
@@ -90,7 +111,11 @@ class TargetOptions {
       TypeID typeID, StringRef toolkitPath = {},
       ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
-      function_ref<SymbolTable *()> getSymbolTableCallback = {});
+      function_ref<SymbolTable *()> getSymbolTableCallback = {},
+      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<void(StringRef)> isaCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
   /// being serialized.
   function_ref<SymbolTable *()> getSymbolTableCallback;
 
+  /// Callback invoked with the initial LLVM IR for the device module.
+  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// linking the device libraries.
+  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// LLVM optimizations but before codegen.
+  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+  /// Callback invoked with the target ISA for the device,
+  /// for example PTX assembly.
+  function_ref<void(StringRef)> isaCallback;
+
 private:
   TypeID typeID;
 };
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index e40d7e9a43dd6b..07fc55b41ae9c5 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -29,8 +29,13 @@ class ModuleTranslation;
 /// operations being transformed must be translatable into LLVM IR.
 class ModuleToObject {
 public:
-  ModuleToObject(Operation &module, StringRef triple, StringRef chip,
-                 StringRef features = {}, int optLevel = 3);
+  ModuleToObject(
+      Operation &module, StringRef triple, StringRef chip,
+      StringRef features = {}, int optLevel = 3,
+      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<void(StringRef)> isaCallback = {});
   virtual ~ModuleToObject();
 
   /// Returns the operation being serialized.
@@ -114,6 +119,21 @@ class ModuleToObject {
   /// Optimization level.
   int optLevel;
 
+  /// Callback invoked with the initial LLVM IR for the device module.
+  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// linking the device libraries.
+  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// LLVM optimizations but before codegen.
+  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+  /// Callback invoked with the target ISA for the device,
+  /// for example PTX assembly.
+  function_ref<void(StringRef)> isaCallback;
+
 private:
   /// The TargetMachine created for the given Triple, if available.
   /// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 956877497d9338..d62ea72dcea2f6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
 TargetOptions::TargetOptions(
     StringRef toolkitPath, ArrayRef<std::string> linkFiles,
     StringRef cmdOptions, CompilationTarget compilationTarget,
-    function_ref<SymbolTable *()> getSymbolTableCallback)
+    function_ref<SymbolTable *()> getSymbolTableCallback,
+    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<void(StringRef)> isaCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
-                    cmdOptions, compilationTarget, getSymbolTableCallback) {}
+                    cmdOptions, compilationTarget, getSymbolTableCallback,
+                    initialLlvmIRCallback, linkedLlvmIRCallback,
+                    optimizedLlvmIRCallback, isaCallback) {}
 
 TargetOptions::TargetOptions(
     TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
     StringRef cmdOptions, CompilationTarget compilationTarget,
-    function_ref<SymbolTable *()> getSymbolTableCallback)
+    function_ref<SymbolTable *()> getSymbolTableCallback,
+    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<void(StringRef)> isaCallback)
     : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
       cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
-      getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
+      getSymbolTableCallback(getSymbolTableCallback),
+      initialLlvmIRCallback(initialLlvmIRCallback),
+      linkedLlvmIRCallback(linkedLlvmIRCallback),
+      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+      isaCallback(isaCallback), typeID(typeID) {}
 
 TypeID TargetOptions::getTypeID() const { return typeID; }
 
@@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
   return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
 }
 
+function_ref<void(llvm::Module &)>
+TargetOptions::getInitialLlvmIRCallback() const {
+  return initialLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getLinkedLlvmIRCallback() const {
+  return linkedLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getOptimizedLlvmIRCallback() const {
+  return optimizedLlvmIRCallback;
+}
+
+function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+  return isaCallback;
+}
+
 CompilationTarget TargetOptions::getCompilationTarget() const {
   return compilationTarget;
 }
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 77391341adaad2..3f5b3d5e31864b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -34,10 +34,17 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
-ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
-                               StringRef chip, StringRef features, int optLevel)
+ModuleToObject::ModuleToObject(
+    Operation &module, StringRef triple, StringRef chip, StringRef features,
+    int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<void(StringRef)> isaCallback)
     : module(module), triple(triple), chip(chip), features(features),
-      optLevel(optLevel) {}
+      optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
+      linkedLlvmIRCallback(linkedLlvmIRCallback),
+      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+      isaCallback(isaCallback) {}
 
 ModuleToObject::~ModuleToObject() = default;
 
@@ -215,6 +222,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
   }
   setDataLayoutAndTriple(*llvmModule);
 
+  if (initialLlvmIRCallback)
+    initialLlvmIRCallback(*llvmModule);
+
   // Link bitcode files.
   handleModulePreLink(*llvmModule);
   {
@@ -227,10 +237,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
     handleModulePostLink(*llvmModule);
   }
 
+  if (linkedLlvmIRCallback)
+    linkedLlvmIRCallback(*llvmModule);
+
   // Optimize the module.
   if (failed(optimizeModule(*llvmModule, optLevel)))
     return std::nullopt;
 
+  if (optimizedLlvmIRCallback)
+    optimizedLlvmIRCallback(*llvmModule);
+
   // Return the serialized object.
   return moduleToObject(*llvmModule);
 }
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 69602af8563aa0..bca26e3a0e84a9 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -86,7 +86,11 @@ SerializeGPUModuleBase::SerializeGPUModuleBase(
     Operation &module, NVVMTargetAttr target,
     const gpu::TargetOptions &targetOptions)
     : ModuleToObject(module, target.getTriple(), target.getChip(),
-                     target.getFeatures(), target.getO()),
+                     target.getFeatures(), target.getO(),
+                     targetOptions.getInitialLlvmIRCallback(),
+                     targetOptions.getLinkedLlvmIRCallback(),
+                     targetOptions.getOptimizedLlvmIRCallback(),
+                     targetOptions.getISACallback()),
       target(target), toolkitPath(targetOptions.getToolkitPath()),
       fileList(targetOptions.getLinkFiles()) {
 
@@ -572,6 +576,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
     getOperation().emitError() << "Failed translating the module to ISA.";
     return std::nullopt;
   }
+  if (isaCallback)
+    isaCallback(serializedISA.value());
+
 #define DEBUG_TYPE "serialize-to-isa"
   LLVM_DEBUG({
     llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index 642aa045178095..eb0c5358ab3530 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -156,3 +156,65 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) {
     ASSERT_TRUE(!object->empty());
   }
 }
+
+// Test that all four callbacks (initial/linked/optimized LLVM IR, ISA) fire.
+TEST_F(MLIRTargetLLVMNVVM,
+       SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) {
+  if (!hasPtxas())
+    GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
+
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+
+  std::string initialLLVMIR;
+  auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+    llvm::raw_string_ostream ros(initialLLVMIR);
+    module.print(ros, nullptr);
+  };
+
+  std::string linkedLLVMIR;
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+    llvm::raw_string_ostream ros(linkedLLVMIR);
+    module.print(ros, nullptr);
+  };
+
+  std::string optimizedLLVMIR;
+  auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+    llvm::raw_string_ostream ros(optimizedLLVMIR);
+    module.print(ros, nullptr);
+  };
+
+  std::string isaResult;
+  auto isaCallback = [&isaResult](llvm::StringRef isa) {
+    isaResult = isa.str();
+  };
+
+  gpu::TargetOptions options({}, {}, {}, gpu::CompilationTarget::Binary, {},
+                             initialCallback, linkedCallback, optimizedCallback,
+                             isaCallback);
+
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+    std::optional<SmallVector<char, 0>> object =
+        serializer.serializeToObject(gpuModule, options);
+
+    ASSERT_TRUE(object != std::nullopt);
+    ASSERT_TRUE(!object->empty());
+    ASSERT_TRUE(!initialLLVMIR.empty());
+    ASSERT_TRUE(!linkedLLVMIR.empty());
+    ASSERT_TRUE(!optimizedLLVMIR.empty());
+    ASSERT_TRUE(!isaResult.empty());
+
+    initialLLVMIR.clear(); // Reset captures for the next GPU module.
+    linkedLLVMIR.clear();
+    optimizedLLVMIR.clear();
+    isaResult.clear();
+  }
+}
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 0d4277ed2fdfdc..e8f2d4bd4b3f84 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -105,7 +105,9 @@ TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
   // Set a dummy attr to be retrieved by `createObject`.
   module->setAttr("serialize_attr", UnitAttr::get(module->getContext()));
   std::string targetTriple = llvm::sys::getProcessTriple();
-  LLVM::ModuleToObject serializer(*module, targetTriple, "", "");
+  LLVM::ModuleToObject serializer(
+      *module, targetTriple, "", "", 3, options.getInitialLlvmIRCallback(),
+      options.getLinkedLlvmIRCallback(), options.getOptimizedLlvmIRCallback());
   return serializer.run();
 }
 
@@ -153,3 +155,91 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) {
   // `serializeToObject`.
   ASSERT_TRUE(properties.contains("serialize_attr"));
 }
+
+// Test that the initial LLVM IR callback is invoked during serialization.
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
+  MLIRContext context(registry);
+  context.loadAllAvailableDialects();
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  std::string initialLLVMIR;
+  auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+    llvm::raw_string_ostream ros(initialLLVMIR);
+    module.print(ros, nullptr);
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
+      initialCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary != std::nullopt);
+  ASSERT_TRUE(!serializedBinary->empty());
+  ASSERT_TRUE(!initialLLVMIR.empty());
+}
+
+// Test that the linked LLVM IR callback is invoked during serialization.
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
+  MLIRContext context(registry);
+  context.loadAllAvailableDialects();
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  std::string linkedLLVMIR;
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+    llvm::raw_string_ostream ros(linkedLLVMIR);
+    module.print(ros, nullptr);
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
+      {}, linkedCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary != std::nullopt);
+  ASSERT_TRUE(!serializedBinary->empty());
+  ASSERT_TRUE(!linkedLLVMIR.empty());
+}
+
+// Test that the optimized LLVM IR callback is invoked during serialization.
+TEST_F(MLIRTargetLLVM,
+       SKIP_WITHOUT_NATIVE(CallbackInvokedWithOptimizedLLVMIR)) {
+  MLIRContext context(registry);
+  context.loadAllAvailableDialects();
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  std::string optimizedLLVMIR;
+  auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+    llvm::raw_string_ostream ros(optimizedLLVMIR);
+    module.print(ros, nullptr);
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
+      {}, {}, optimizedCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary != std::nullopt);
+  ASSERT_TRUE(!serializedBinary->empty());
+  ASSERT_TRUE(!optimizedLLVMIR.empty());
+}
\ No newline at end of file



More information about the Mlir-commits mailing list