[Mlir-commits] [mlir] [mlir] support dialect attribute translation to LLVM IR (PR #75309)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Dec 15 08:28:04 PST 2023


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/75309

>From f074a9e6b25dae4a11a3c11056506852841b90bf Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Fri, 15 Dec 2023 16:25:31 +0000
Subject: [PATCH] [mlir] support dialect attribute translation to LLVM IR

Extend the `amendOperation` mechanism for translating dialect attributes
attached to operations from another dialect when translating MLIR to
LLVM IR. Previously, this mechanism would have no knowledge of the LLVM
IR instructions created for the given operation, making it impossible
for it to perform local modifications such as attaching operation-level
metadata. Collect instructions inserted by the LLVM IR builder and pass
them to `amendOperation`.

Note that this currently doesn't extend to OpenMP operations because they go
through OpenMPIRBuilder, which lacks the required extensibility hooks.
Specifically, it owns an instance of IRBuilder with default template
arguments, making it impossible to use the custom inserter class needed
for intercepting LLVM IR instructions after they are created.
---
 .../Target/LLVMIR/LLVMTranslationInterface.h  |  10 +-
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  37 ++++-
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  |   3 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |   6 +-
 .../ROCDL/ROCDLToLLVMIRTranslation.cpp        |   3 +-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 127 +++++++++++++++++-
 mlir/test/Target/LLVMIR/test.mlir             |  24 ++++
 .../Dialect/Test/TestToLLVMIRTranslation.cpp  |  18 ++-
 8 files changed, 211 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 0531c0ec953fe2..19991a6f89d80f 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -18,6 +18,7 @@
 #include "mlir/Support/LogicalResult.h"
 
 namespace llvm {
+class Instruction;
 class IRBuilderBase;
 } // namespace llvm
 
@@ -52,7 +53,8 @@ class LLVMTranslationDialectInterface
   /// translation results and amend the corresponding IR constructs. Does
   /// nothing and succeeds by default.
   virtual LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const {
     return success();
   }
@@ -78,11 +80,13 @@ class LLVMTranslationInterface
   /// Acts on the given operation using the interface implemented by the dialect
   /// of one of the operation's dialect attributes.
   virtual LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const {
     if (const LLVMTranslationDialectInterface *iface =
             getInterfaceFor(attribute.getNameDialect())) {
-      return iface->amendOperation(op, attribute, moduleTranslation);
+      return iface->amendOperation(op, instructions, attribute,
+                                   moduleTranslation);
     }
     return success();
   }
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 3f797f41f10ab0..f949ee76063da4 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -315,7 +315,9 @@ class ModuleTranslation {
   LogicalResult createTBAAMetadata();
 
   /// Translates dialect attributes attached to the given operation.
-  LogicalResult convertDialectAttributes(Operation *op);
+  LogicalResult
+  convertDialectAttributes(Operation *op,
+                           ArrayRef<llvm::Instruction *> instructions);
 
   /// Translates parameter attributes and adds them to the returned AttrBuilder.
   llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
@@ -379,6 +381,39 @@ class ModuleTranslation {
 
   /// A cache for the symbol tables constructed during symbols lookup.
   SymbolTableCollection symbolTableCollection;
+
+  /// Indicates whether the next call to convertOperation in the IR translation
+  /// process will receive an instance of CapturingIRBuilder or a plain
+  /// llvm::IRBuilderBase. This is a poor man's `isa` on an object that doesn't
+  /// have it. It is required because OpenMPIRBuilder holds its own IRBuilder,
+  /// incompatible with CapturingIRBuilder, and may pass it to convertOperation
+  /// when converting nested blocks.
+  // TODO: consider parameterizing OpenMPIRBuilder with the type of IR builder
+  // to use.
+  bool nestedConvertOperationUsesCapturingBuilder = false;
+
+  /// RAII object setting `nestedConvertOperationUsesCapturingBuilder` to the
+  /// given value for its lifetime, restoring the previous value on exit.
+  class CapturingBuilderFlagScope {
+  public:
+    CapturingBuilderFlagScope(ModuleTranslation &moduleTranslation,
+                              bool temporaryValue)
+        : moduleTranslation(moduleTranslation),
+          previousValue(
+              moduleTranslation.nestedConvertOperationUsesCapturingBuilder) {
+      moduleTranslation.nestedConvertOperationUsesCapturingBuilder =
+          temporaryValue;
+    }
+    ~CapturingBuilderFlagScope() {
+      moduleTranslation.nestedConvertOperationUsesCapturingBuilder =
+          previousValue;
+    }
+    CapturingBuilderFlagScope(const CapturingBuilderFlagScope &) = delete;
+
+  private:
+    ModuleTranslation &moduleTranslation;
+    bool previousValue;
+  };
 };
 
 namespace detail {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 88e3a455340750..0d6bca5e2203ea 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -142,7 +142,8 @@ class NVVMDialectLLVMIRTranslationInterface
 
   /// Attaches module-level metadata for functions marked as kernels.
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
     auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
     if (!func)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a6..28b1ad1d2f8b4a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2555,14 +2555,16 @@ class OpenMPDialectLLVMIRTranslationInterface
   /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
   /// calls, or operation amendments
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final;
 };
 
 } // namespace
 
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
-    Operation *op, NamedAttribute attribute,
+    Operation *op, ArrayRef<llvm::Instruction *> instructions,
+    NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
   return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
              attribute.getName())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 5ab70280f6c818..55a6285ec87eb4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -81,7 +81,8 @@ class ROCDLDialectLLVMIRTranslationInterface
 
   /// Attaches module-level metadata for functions marked as kernels.
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
     if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d6afe354178d66..32f613a8a8abe5 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -59,6 +59,113 @@ using namespace mlir::LLVM::detail;
 
 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
 
+namespace {
+/// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
+/// instructions that are created for future reference.
+///
+/// This is intended to be used with the `CollectionScope` RAII object:
+///
+///     llvm::IRBuilder<..., InstructionCapturingInserter> builder;
+///     {
+///       InstructionCapturingInserter::CollectionScope scope(builder, true);
+///       // Call IRBuilder methods as usual.
+///
+///       // This will return a list of all instructions created by the builder,
+///       // in order of creation.
+///       builder.getInserter().getCapturedInstructions();
+///     }
+///     // Outside of any scope, this will return an empty list.
+///     builder.getInserter().getCapturedInstructions();
+///
+/// The capturing functionality is _disabled_ by default for performance
+/// reasons. It needs to be explicitly enabled, which is achieved by
+/// creating a `CollectionScope`.
+class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
+public:
+  /// Constructs the inserter.
+  InstructionCapturingInserter()
+      : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
+          if (LLVM_LIKELY(enabled))
+            capturedInstructions.push_back(instruction);
+        }) {}
+
+  /// Returns the list of LLVM IR instructions captured since the last cleanup.
+  ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
+    return capturedInstructions;
+  }
+
+  /// Clears the list of captured LLVM IR instructions.
+  void clearCapturedInstructions() { capturedInstructions.clear(); }
+
+  /// RAII object enabling the capture of created LLVM IR instructions.
+  class CollectionScope {
+  public:
+    /// Creates the scope. It only captures if `isBuilderCapturing` is set, in
+    /// which case `irBuilder` must actually be a CapturingIRBuilder.
+    CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
+    /// Ends the scope, restoring the inserter's previous enabled state.
+    ~CollectionScope();
+    CollectionScope(const CollectionScope &) = delete;
+    CollectionScope &operator=(const CollectionScope &) = delete;
+    /// Returns the instructions captured so far in this scope, if any.
+    ArrayRef<llvm::Instruction *> getCapturedInstructions() {
+      if (!inserter)
+        return {};
+      return inserter->getCapturedInstructions();
+    }
+
+  private:
+    /// Back reference to the inserter.
+    InstructionCapturingInserter *inserter = nullptr;
+    /// List of instructions in the inserter prior to this scope.
+    SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
+    /// Whether the inserter was enabled prior to this scope.
+    bool wasEnabled = false;
+  };
+
+  /// Enable or disable the capturing mechanism.
+  void setEnabled(bool enabled = true) { this->enabled = enabled; }
+
+private:
+  /// List of captured instructions.
+  SmallVector<llvm::Instruction *> capturedInstructions;
+
+  /// Whether the collection is enabled.
+  bool enabled = false;
+};
+
+using CapturingIRBuilder =
+    llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>;
+} // namespace
+
+InstructionCapturingInserter::CollectionScope::CollectionScope(
+    llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
+  if (!isBuilderCapturing)
+    return;
+
+  auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
+  inserter = &capturingIRBuilder.getInserter();
+  wasEnabled = inserter->enabled;
+  // Stash instructions collected by an enclosing scope; restored on exit.
+  if (wasEnabled)
+    previouslyCollectedInstructions.swap(inserter->capturedInstructions);
+  inserter->setEnabled(true);
+}
+
+InstructionCapturingInserter::CollectionScope::~CollectionScope() {
+  if (!inserter)
+    return;
+
+  previouslyCollectedInstructions.swap(inserter->capturedInstructions);
+  // If collection was enabled (likely in another, surrounding scope), keep
+  // the instructions collected in this scope.
+  if (wasEnabled) {
+    llvm::append_range(inserter->capturedInstructions,
+                       previouslyCollectedInstructions);
+  }
+  inserter->setEnabled(wasEnabled);
+}
+
 /// Translates the given data layout spec attribute to the LLVM IR data layout.
 /// Only integer, float, pointer and endianness entries are currently supported.
 static FailureOr<llvm::DataLayout>
@@ -641,11 +748,15 @@ ModuleTranslation::convertOperation(Operation &op,
                         "dialect for op: ")
            << op.getName();
 
+  InstructionCapturingInserter::CollectionScope scope(
+      builder, nestedConvertOperationUsesCapturingBuilder);
+  CapturingBuilderFlagScope flagScope(*this, false);
+
   if (failed(opIface->convertOperation(&op, builder, *this)))
     return op.emitError("LLVM Translation failed for operation: ")
            << op.getName();
 
-  return convertDialectAttributes(&op);
+  return convertDialectAttributes(&op, scope.getCapturedInstructions());
 }
 
 /// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
@@ -843,7 +954,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
   }
 
   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
-    if (failed(convertDialectAttributes(op)))
+    if (failed(convertDialectAttributes(op, {})))
       return failure();
 
   // Finally, update the compile units their respective sets of global variables
@@ -994,7 +1105,8 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   // converted before uses.
   auto blocks = getTopologicallySortedBlocks(func.getBody());
   for (Block *bb : blocks) {
-    llvm::IRBuilder<> builder(llvmContext);
+    CapturingIRBuilder builder(llvmContext);
+    CapturingBuilderFlagScope scope(*this, true);
     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
       return failure();
   }
@@ -1004,12 +1116,13 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   detail::connectPHINodes(func.getBody(), *this);
 
   // Finally, convert dialect attributes attached to the function.
-  return convertDialectAttributes(func);
+  return convertDialectAttributes(func, {});
 }
 
-LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
+LogicalResult ModuleTranslation::convertDialectAttributes(
+    Operation *op, ArrayRef<llvm::Instruction *> instructions) {
   for (NamedAttribute attribute : op->getDialectAttrs())
-    if (failed(iface.amendOperation(op, attribute, *this)))
+    if (failed(iface.amendOperation(op, instructions, attribute, *this)))
       return failure();
   return success();
 }
@@ -1131,7 +1244,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
     // Do not convert external functions, but do process dialect attributes
     // attached to them.
     if (function.isExternal()) {
-      if (failed(convertDialectAttributes(function)))
+      if (failed(convertDialectAttributes(function, {})))
         return failure();
       continue;
     }
diff --git a/mlir/test/Target/LLVMIR/test.mlir b/mlir/test/Target/LLVMIR/test.mlir
index f48738f44f44b4..0ab1b7267d9598 100644
--- a/mlir/test/Target/LLVMIR/test.mlir
+++ b/mlir/test/Target/LLVMIR/test.mlir
@@ -16,3 +16,27 @@ module {
 module attributes {test.discardable_mod_attr = true} {}
 
 // CHECK: @sym_from_attr = external global i32
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation
+llvm.func @dialect_attr_translation() {
+  // CHECK: ret void, !annotation ![[MD_ID:.+]]
+  llvm.return {test.add_annotation}
+}
+// CHECK: ![[MD_ID]] = !{!"annotation_from_test"}
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation_multi
+llvm.func @dialect_attr_translation_multi(%a: i64, %b: i64, %c: i64) -> i64 {
+  // CHECK: add {{.*}}, !annotation ![[MD_ID_ADD:.+]]
+  // CHECK: mul {{.*}}, !annotation ![[MD_ID_MUL:.+]]
+  // CHECK: ret {{.*}}, !annotation ![[MD_ID_RET:.+]]
+  %ab = llvm.add %a, %b {test.add_annotation = "add"} : i64
+  %r = llvm.mul %ab, %c {test.add_annotation = "mul"} : i64
+  llvm.return {test.add_annotation = "ret"} %r : i64
+}
+// CHECK-DAG: ![[MD_ID_ADD]] = !{!"annotation_from_test: add"}
+// CHECK-DAG: ![[MD_ID_MUL]] = !{!"annotation_from_test: mul"}
+// CHECK-DAG: ![[MD_ID_RET]] = !{!"annotation_from_test: ret"}
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index 7110d999c8f8ae..2dd99c67c1439b 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -32,7 +32,8 @@ class TestDialectLLVMIRTranslationInterface
   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
 
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final;
 
   LogicalResult
@@ -43,7 +44,8 @@ class TestDialectLLVMIRTranslationInterface
 } // namespace
 
 LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
-    Operation *op, NamedAttribute attribute,
+    Operation *op, ArrayRef<llvm::Instruction *> instructions,
+    NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
   return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
              attribute.getName())
@@ -70,6 +72,18 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
                     /*sym_visibility=*/nullptr);
               }
 
+              return success();
+            })
+      .Case("test.add_annotation",
+            [&](Attribute attr) {
+              for (llvm::Instruction *instruction : instructions) {
+                if (auto strAttr = dyn_cast<StringAttr>(attr)) {
+                  instruction->addAnnotationMetadata("annotation_from_test: " +
+                                                     strAttr.getValue().str());
+                } else {
+                  instruction->addAnnotationMetadata("annotation_from_test");
+                }
+              }
               return success();
             })
       .Default([](Attribute) {



More information about the Mlir-commits mailing list