[Mlir-commits] [mlir] [mlir][gpu] Add a symbol table field to TargetOptions and adjust GpuModuleToBinary (PR #65797)

Fabian Mora llvmlistbot at llvm.org
Sat Sep 9 16:44:24 PDT 2023


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/65797:

>From 48ee1a33b156d74c82b1f91167678cdcfe3892f3 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Fri, 8 Sep 2023 19:32:24 +0000
Subject: [PATCH 1/3] [mlir][gpu] Add a symbol table field to TargetOptions and
 adjust GpuModuleToBinary

This patch adds an option to the `gpu-module-to-binary` pass for building a
symbol table for the pass's top-level operation. Some target attributes require
this table; because other targets do not, it is not built by default. The table
is passed to the targets through `TargetOptions`.
---
 .../mlir/Dialect/GPU/IR/CompilationInterfaces.h    | 14 ++++++++++++--
 mlir/include/mlir/Dialect/GPU/Transforms/Passes.td | 13 +++++++++----
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp             | 12 ++++++++----
 mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp |  9 ++++++++-
 4 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index e0bf560dbd98b92..a1683ca477fa9d1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -20,6 +20,7 @@ class IRBuilderBase;
 }
 
 namespace mlir {
+class SymbolTable;
 namespace LLVM {
 class ModuleTranslation;
 }
@@ -59,7 +60,8 @@ class TargetOptions {
   /// compilation target is `binary`.
   TargetOptions(StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
-                CompilationTarget compilationTarget = binOrFatbin);
+                CompilationTarget compilationTarget = binOrFatbin,
+                SymbolTable *parentTable = nullptr);
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -80,12 +82,16 @@ class TargetOptions {
   /// Returns the compilation target.
   CompilationTarget getCompilationTarget() const;
 
+  /// Returns the provided parent symbol table.
+  SymbolTable *getParentTable() const;
+
 protected:
   /// Derived classes must use this constructor to initialize `typeID` to the
   /// appropiate value: ie. `TargetOptions(TypeID::get<DerivedClass>())`.
   TargetOptions(TypeID typeID, StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
-                CompilationTarget compilationTarget = binOrFatbin);
+                CompilationTarget compilationTarget = binOrFatbin,
+                SymbolTable *parentTable = nullptr);
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -100,6 +106,10 @@ class TargetOptions {
   /// Compilation process target representation.
   CompilationTarget compilationTarget;
 
+  /// Parent symbol table of all the GPU modules being serialized. By default
+  /// this member is null as it is not required by most targets.
+  SymbolTable *parentTable;
+
 private:
   TypeID typeID;
 };
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index ba8a6266604e46c..4588cf6ef4b310a 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -64,9 +64,11 @@ def GpuModuleToBinaryPass
     with an object for every target.
 
     The `format` argument can have the following values:
-    1. `offloading`, `llvm`: producing an offloading representation.
-    2. `assembly`, `isa`: producing assembly code.
-    3. `binary`, `bin`: producing binaries.
+    1. `offloading`, `llvm`: produces an offloading representation.
+    2. `assembly`, `isa`: produces assembly code.
+    3. `binary`, `bin`: produces binaries.
+    4. `fatbinary`, `fatbin`: produces fatbinaries.
+    5. `binOrFatbin`: produces bins or fatbins, the target decides which.
   }];
   let options = [
     Option<"offloadingHandler", "handler", "Attribute", "nullptr",
@@ -78,7 +80,10 @@ def GpuModuleToBinaryPass
     Option<"cmdOptions", "opts", "std::string", [{""}],
            "Command line options to pass to the tools.">,
     Option<"compilationTarget", "format", "std::string", [{"binOrFatbin"}],
-           "The target representation of the compilation process.">
+           "The target representation of the compilation process.">,
+    Option<"constructSymbolTable", "symbol-table", "bool", [{false}],
+           "Enable building a symbol table enclosing the modules to serialize."
+           "Most targets can safely disable this.">
   ];
 }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index f417a083337fcaf..a53a120c7ef8a84 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1982,17 +1982,19 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 TargetOptions::TargetOptions(StringRef toolkitPath,
                              ArrayRef<std::string> linkFiles,
                              StringRef cmdOptions,
-                             CompilationTarget compilationTarget)
+                             CompilationTarget compilationTarget,
+                             SymbolTable *parentTable)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
-                    cmdOptions, compilationTarget) {}
+                    cmdOptions, compilationTarget, parentTable) {}
 
 TargetOptions::TargetOptions(TypeID typeID, StringRef toolkitPath,
                              ArrayRef<std::string> linkFiles,
                              StringRef cmdOptions,
-                             CompilationTarget compilationTarget)
+                             CompilationTarget compilationTarget,
+                             SymbolTable *parentTable)
     : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
       cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
-      typeID(typeID) {}
+      parentTable(parentTable), typeID(typeID) {}
 
 TypeID TargetOptions::getTypeID() const { return typeID; }
 
@@ -2002,6 +2004,8 @@ ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
 
 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
 
+SymbolTable *TargetOptions::getParentTable() const { return parentTable; }
+
 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
 TargetOptions::tokenizeCmdOptions() const {
   std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
index 06b7dee6941e1f4..1eb9cf7f941d0ba 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -66,9 +66,16 @@ void GpuModuleToBinaryPass::runOnOperation() {
                          .Default(-1);
   if (targetFormat == -1)
     getOperation()->emitError() << "Invalid format specified.";
+
+  std::unique_ptr<SymbolTable> parentTable;
+  // Create the symbol table if it was requested in the pass options.
+  if (constructSymbolTable)
+    parentTable = std::unique_ptr<SymbolTable>(new SymbolTable(getOperation()));
+
   TargetOptions targetOptions(
       toolkitPath, linkFiles, cmdOptions,
-      static_cast<TargetOptions::CompilationTarget>(targetFormat));
+      static_cast<TargetOptions::CompilationTarget>(targetFormat),
+      parentTable.get());
   if (failed(transformGpuModulesToBinaries(
           getOperation(),
           offloadingHandler ? dyn_cast<OffloadingLLVMTranslationAttrInterface>(

>From 24cd7f56d29baa11c138dfdca0bfbee782d3728e Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sat, 9 Sep 2023 16:35:11 +0000
Subject: [PATCH 2/3] Replace the eager symbol table with a lazy builder callback.

---
 .../Dialect/GPU/IR/CompilationInterfaces.h    | 15 +++++++-------
 .../mlir/Dialect/GPU/Transforms/Passes.td     |  5 +----
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 12 ++++++-----
 .../Dialect/GPU/Transforms/ModuleToBinary.cpp | 20 ++++++++++++++-----
 4 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index a1683ca477fa9d1..cd2287ecf239d36 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -57,11 +57,11 @@ class TargetOptions {
 
   /// Constructor initializing the toolkit path, the list of files to link to,
   /// extra command line options & the compilation target. The default
-  /// compilation target is `binary`.
+  /// compilation target is `binOrFatbin`.
   TargetOptions(StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
                 CompilationTarget compilationTarget = binOrFatbin,
-                SymbolTable *parentTable = nullptr);
+                function_ref<SymbolTable *()> symbolTableCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -82,7 +82,8 @@ class TargetOptions {
   /// Returns the compilation target.
   CompilationTarget getCompilationTarget() const;
 
-  /// Returns the provided parent symbol table.
+  /// Returns the parent symbol table if a callback was provided, else returns
+  /// nullptr.
   SymbolTable *getParentTable() const;
 
 protected:
@@ -91,7 +92,7 @@ class TargetOptions {
   TargetOptions(TypeID typeID, StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
                 CompilationTarget compilationTarget = binOrFatbin,
-                SymbolTable *parentTable = nullptr);
+                function_ref<SymbolTable *()> symbolTableCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -106,9 +107,9 @@ class TargetOptions {
   /// Compilation process target representation.
   CompilationTarget compilationTarget;
 
-  /// Parent symbol table of all the GPU modules being serialized. By default
-  /// this member is null as it is not required by most targets.
-  SymbolTable *parentTable;
+  /// Callback for obtaining the parent symbol table of all the GPU modules
+  /// being serialized.
+  function_ref<SymbolTable *()> symbolTableCallback;
 
 private:
   TypeID typeID;
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 4588cf6ef4b310a..0bfb2750992058f 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -80,10 +80,7 @@ def GpuModuleToBinaryPass
     Option<"cmdOptions", "opts", "std::string", [{""}],
            "Command line options to pass to the tools.">,
     Option<"compilationTarget", "format", "std::string", [{"binOrFatbin"}],
-           "The target representation of the compilation process.">,
-    Option<"constructSymbolTable", "symbol-table", "bool", [{false}],
-           "Enable building a symbol table enclosing the modules to serialize."
-           "Most targets can safely disable this.">
+           "The target representation of the compilation process.">
   ];
 }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a53a120c7ef8a84..68423e74b7f328f 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1983,18 +1983,18 @@ TargetOptions::TargetOptions(StringRef toolkitPath,
                              ArrayRef<std::string> linkFiles,
                              StringRef cmdOptions,
                              CompilationTarget compilationTarget,
-                             SymbolTable *parentTable)
+                             function_ref<SymbolTable *()> symbolTableCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
-                    cmdOptions, compilationTarget, parentTable) {}
+                    cmdOptions, compilationTarget, symbolTableCallback) {}
 
 TargetOptions::TargetOptions(TypeID typeID, StringRef toolkitPath,
                              ArrayRef<std::string> linkFiles,
                              StringRef cmdOptions,
                              CompilationTarget compilationTarget,
-                             SymbolTable *parentTable)
+                             function_ref<SymbolTable *()> symbolTableCallback)
     : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
       cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
-      parentTable(parentTable), typeID(typeID) {}
+      symbolTableCallback(symbolTableCallback), typeID(typeID) {}
 
 TypeID TargetOptions::getTypeID() const { return typeID; }
 
@@ -2004,7 +2004,9 @@ ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
 
 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
 
-SymbolTable *TargetOptions::getParentTable() const { return parentTable; }
+SymbolTable *TargetOptions::getParentTable() const {
+  return symbolTableCallback ? symbolTableCallback() : nullptr;
+}
 
 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
 TargetOptions::tokenizeCmdOptions() const {
diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
index 1eb9cf7f941d0ba..e29a1f0c3248d04 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -67,15 +67,25 @@ void GpuModuleToBinaryPass::runOnOperation() {
   if (targetFormat == -1)
     getOperation()->emitError() << "Invalid format specified.";
 
-  std::unique_ptr<SymbolTable> parentTable;
-  // Create the symbol table if it was requested in the pass options.
-  if (constructSymbolTable)
-    parentTable = std::unique_ptr<SymbolTable>(new SymbolTable(getOperation()));
+  // Lazy symbol table builder callback.
+  std::optional<SymbolTable> parentTable;
+  auto lazyTableBuilder = [&]() -> SymbolTable * {
+    // Build the table if it has not been built.
+    if (!parentTable) {
+      Operation *table = SymbolTable::getNearestSymbolTable(getOperation());
+      // It's up to the target attribute to determine if failing to find a
+      // symbol table is an error.
+      if (!table)
+        return nullptr;
+      parentTable = SymbolTable(table);
+    }
+    return &parentTable.value();
+  };
 
   TargetOptions targetOptions(
       toolkitPath, linkFiles, cmdOptions,
       static_cast<TargetOptions::CompilationTarget>(targetFormat),
-      parentTable.get());
+      lazyTableBuilder);
   if (failed(transformGpuModulesToBinaries(
           getOperation(),
           offloadingHandler ? dyn_cast<OffloadingLLVMTranslationAttrInterface>(

>From 7600c49c44cbf3a335558b728bbf79bf64540237 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sat, 9 Sep 2023 23:43:41 +0000
Subject: [PATCH 3/3] Address review comments: rename getParentTable to getSymbolTable and fix docs.

---
 .../Dialect/GPU/IR/CompilationInterfaces.h    | 20 ++++++++------
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 26 +++++++++----------
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index cd2287ecf239d36..a1f64be57fa699d 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -56,12 +56,13 @@ class TargetOptions {
   } CompilationTarget;
 
   /// Constructor initializing the toolkit path, the list of files to link to,
-  /// extra command line options & the compilation target. The default
-  /// compilation target is `binOrFatbin`.
+  /// extra command line options, the compilation target and a callback for
+  /// obtaining the parent symbol table. The default compilation target is
+  /// `binOrFatbin`.
   TargetOptions(StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
                 CompilationTarget compilationTarget = binOrFatbin,
-                function_ref<SymbolTable *()> symbolTableCallback = {});
+                function_ref<SymbolTable *()> getSymbolTableCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -82,9 +83,12 @@ class TargetOptions {
   /// Returns the compilation target.
   CompilationTarget getCompilationTarget() const;
 
-  /// Returns the parent symbol table if a callback was provided, else returns
-  /// nullptr.
-  SymbolTable *getParentTable() const;
+  /// Returns the result of the `getSymbolTableCallback` callback or a nullptr
+  /// if no callback was provided.
+  /// Note: The callback itself can return nullptr. It is up to the target how
+  /// to react to getting a nullptr, e.g., emitting an error or constructing the
+  /// table.
+  SymbolTable *getSymbolTable() const;
 
 protected:
   /// Derived classes must use this constructor to initialize `typeID` to the
@@ -92,7 +96,7 @@ class TargetOptions {
   TargetOptions(TypeID typeID, StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
                 CompilationTarget compilationTarget = binOrFatbin,
-                function_ref<SymbolTable *()> symbolTableCallback = {});
+                function_ref<SymbolTable *()> getSymbolTableCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -109,7 +113,7 @@ class TargetOptions {
 
   /// Callback for obtaining the parent symbol table of all the GPU modules
   /// being serialized.
-  function_ref<SymbolTable *()> symbolTableCallback;
+  function_ref<SymbolTable *()> getSymbolTableCallback;
 
 private:
   TypeID typeID;
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 68423e74b7f328f..46fb1766bc40535 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1979,22 +1979,20 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // GPU target options
 //===----------------------------------------------------------------------===//
 
-TargetOptions::TargetOptions(StringRef toolkitPath,
-                             ArrayRef<std::string> linkFiles,
-                             StringRef cmdOptions,
-                             CompilationTarget compilationTarget,
-                             function_ref<SymbolTable *()> symbolTableCallback)
+TargetOptions::TargetOptions(
+    StringRef toolkitPath, ArrayRef<std::string> linkFiles,
+    StringRef cmdOptions, CompilationTarget compilationTarget,
+    function_ref<SymbolTable *()> getSymbolTableCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
-                    cmdOptions, compilationTarget, symbolTableCallback) {}
+                    cmdOptions, compilationTarget, getSymbolTableCallback) {}
 
-TargetOptions::TargetOptions(TypeID typeID, StringRef toolkitPath,
-                             ArrayRef<std::string> linkFiles,
-                             StringRef cmdOptions,
-                             CompilationTarget compilationTarget,
-                             function_ref<SymbolTable *()> symbolTableCallback)
+TargetOptions::TargetOptions(
+    TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
+    StringRef cmdOptions, CompilationTarget compilationTarget,
+    function_ref<SymbolTable *()> getSymbolTableCallback)
     : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
       cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
-      symbolTableCallback(symbolTableCallback), typeID(typeID) {}
+      getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
 
 TypeID TargetOptions::getTypeID() const { return typeID; }
 
@@ -2004,8 +2002,8 @@ ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
 
 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
 
-SymbolTable *TargetOptions::getParentTable() const {
-  return symbolTableCallback ? symbolTableCallback() : nullptr;
+SymbolTable *TargetOptions::getSymbolTable() const {
+  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
 }
 
 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>



More information about the Mlir-commits mailing list