[Mlir-commits] [mlir] [mlir] support dialect attribute translation to LLVM IR (PR #75309)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Dec 13 02:05:21 PST 2023


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/75309

Extend the `amendOperation` mechanism for translating dialect attributes attached to operations from another dialect when translating MLIR to LLVM IR. Previously, this mechanism had no knowledge of the LLVM IR instructions created for the given operation, making it impossible to perform local modifications such as attaching operation-level metadata. Collect the instructions inserted by the LLVM IR builder and pass them to `amendOperation`.

From b1db027659b9b4820198889b189c9135b73ba11d Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 13 Dec 2023 10:01:26 +0000
Subject: [PATCH] [mlir] support dialect attribute translation to LLVM IR

Extend the `amendOperation` mechanism for translating dialect attributes
attached to operations from another dialect when translating MLIR to
LLVM IR. Previously, this mechanism had no knowledge of the LLVM IR
instructions created for the given operation, making it impossible to
perform local modifications such as attaching operation-level metadata.
Collect the instructions inserted by the LLVM IR builder and pass them
to `amendOperation`.
---
 .../Target/LLVMIR/LLVMTranslationInterface.h  |  22 ++--
 .../mlir/Target/LLVMIR/ModuleTranslation.h    | 111 +++++++++++++++++-
 .../Dialect/AMX/AMXToLLVMIRTranslation.cpp    |   2 +-
 .../ArmNeon/ArmNeonToLLVMIRTranslation.cpp    |   2 +-
 .../ArmSME/ArmSMEToLLVMIRTranslation.cpp      |   2 +-
 .../ArmSVE/ArmSVEToLLVMIRTranslation.cpp      |   2 +-
 .../Builtin/BuiltinToLLVMIRTranslation.cpp    |   2 +-
 .../Dialect/GPU/GPUToLLVMIRTranslation.cpp    |   3 +-
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |   2 +-
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  |   5 +-
 .../OpenACC/OpenACCToLLVMIRTranslation.cpp    |   6 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  88 +++++++-------
 .../ROCDL/ROCDLToLLVMIRTranslation.cpp        |   5 +-
 .../X86VectorToLLVMIRTranslation.cpp          |   2 +-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  29 +++--
 mlir/test/Target/LLVMIR/test.mlir             |  24 ++++
 .../Dialect/Test/TestToLLVMIRTranslation.cpp  |  22 +++-
 17 files changed, 245 insertions(+), 84 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 0531c0ec953fe..84aff94653135 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -18,12 +18,14 @@
 #include "mlir/Support/LogicalResult.h"
 
 namespace llvm {
+class Instruction;
 class IRBuilderBase;
 } // namespace llvm
 
 namespace mlir {
 namespace LLVM {
 class ModuleTranslation;
+class CapturingIRBuilder;
 } // namespace LLVM
 
 /// Base class for dialect interfaces providing translation to LLVM IR.
@@ -40,7 +42,7 @@ class LLVMTranslationDialectInterface
   /// Hook for derived dialect interface to provide translation of the
   /// operations to LLVM IR.
   virtual LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const {
     return failure();
   }
@@ -48,11 +50,13 @@ class LLVMTranslationDialectInterface
   /// Hook for derived dialect interface to act on an operation that has dialect
   /// attributes from the derived dialect (the operation itself may be from a
   /// different dialect). This gets called after the operation has been
-  /// translated. The hook is expected to use moduleTranslation to look up the
-  /// translation results and amend the corresponding IR constructs. Does
-  /// nothing and succeeds by default.
+  /// translated and accepts as second argument the list of LLVM IR instructions
+  /// that were constructed when translating the operation. The hook is expected
+  /// to use moduleTranslation to look up the translation results and amend the
+  /// corresponding IR constructs. Does nothing and succeeds by default.
   virtual LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const {
     return success();
   }
@@ -68,7 +72,7 @@ class LLVMTranslationInterface
   /// Translates the given operation to LLVM IR using the interface implemented
   /// by the op's dialect.
   virtual LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const {
     if (const LLVMTranslationDialectInterface *iface = getInterfaceFor(op))
       return iface->convertOperation(op, builder, moduleTranslation);
@@ -78,11 +82,13 @@ class LLVMTranslationInterface
   /// Acts on the given operation using the interface implemented by the dialect
   /// of one of the operation's dialect attributes.
   virtual LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const {
     if (const LLVMTranslationDialectInterface *iface =
             getInterfaceFor(attribute.getNameDialect())) {
-      return iface->amendOperation(op, attribute, moduleTranslation);
+      return iface->amendOperation(op, instructions, attribute,
+                                   moduleTranslation);
     }
     return success();
   }
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 3f797f41f10ab..72d4de332eef6 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -22,6 +22,7 @@
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 
@@ -50,6 +51,108 @@ class DINodeAttr;
 class LLVMFuncOp;
 class ComdatSelectorOp;
 
+namespace detail {
+/// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
+/// instructions that are created for future reference.
+///
+/// This is intended to be used with the `CollectionScope` RAII object:
+///
+///     llvm::IRBuilder<..., InstructionCapturingInserter> builder;
+///     {
+///       InstructionCapturingInserter::CollectionScope scope(builder);
+///       // Call IRBuilder methods as usual.
+///
+///       // This will return a list of all instructions created by the builder,
+///       // in order of creation.
+///       builder.getInserter().getCapturedInstructions();
+///     }
+///     // This will return an empty list.
+///     builder.getInserter().getCapturedInstructions();
+///
+/// The capturing functionality is _disabled_ by default for performance
+/// reasons. It needs to be explicitly enabled, which is achieved by
+/// creating a `CollectionScope`.
+class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
+public:
+  /// Constructs the inserter.
+  InstructionCapturingInserter()
+      : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
+          if (LLVM_LIKELY(enabled))
+            capturedInstructions.push_back(instruction);
+        }) {}
+
+  /// Returns the list of LLVM IR instructions captured since the last cleanup.
+  ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
+    return capturedInstructions;
+  }
+
+  /// Clears the list of captured LLVM IR instructions.
+  void clearCapturedInstructions() { capturedInstructions.clear(); }
+
+  /// RAII object enabling the capture of created LLVM IR instructions.
+  class CollectionScope {
+  public:
+    /// Creates the scope for the given inserter.
+    explicit CollectionScope(InstructionCapturingInserter &reference)
+        : reference(reference) {
+      wasEnabled = reference.enabled;
+      if (wasEnabled)
+        previouslyCollectedInstructions.swap(reference.capturedInstructions);
+      reference.enable(true);
+    }
+
+    /// Creates the scope for the given `llvm::IRBuilder`.
+    template <typename Ty,
+              typename = std::enable_if_t<std::is_base_of_v<
+                  std::remove_reference_t<
+                      decltype(std::declval<Ty &>().getInserter())>,
+                  InstructionCapturingInserter>>>
+    explicit CollectionScope(Ty &builder)
+        : CollectionScope(builder.getInserter()) {}
+
+    /// Ends the scope.
+    ~CollectionScope() {
+      previouslyCollectedInstructions.swap(reference.capturedInstructions);
+      // If collection was enabled (likely in another, surrounding scope), keep
+      // the instructions collected in this scope.
+      if (wasEnabled) {
+        llvm::append_range(reference.capturedInstructions,
+                           previouslyCollectedInstructions);
+      }
+      reference.enable(wasEnabled);
+    }
+
+  private:
+    /// Back reference to the inserter.
+    InstructionCapturingInserter &reference;
+
+    /// List of instructions in the inserter prior to this scope.
+    SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
+
+    /// Whether the inserter was enabled prior to this scope.
+    bool wasEnabled;
+  };
+
+  void enable(bool enabled) { this->enabled = enabled; }
+
+private:
+  /// List of captured instructions.
+  SmallVector<llvm::Instruction *> capturedInstructions;
+
+  /// Whether the collection is enabled.
+  bool enabled = false;
+};
+} // namespace detail
+
+// A class rather than a "using" declaration to support forward declarations
+// elsewhere.
+class CapturingIRBuilder
+    : public llvm::IRBuilder<llvm::ConstantFolder,
+                             detail::InstructionCapturingInserter> {
+public:
+  using IRBuilder::IRBuilder;
+};
+
 /// Implementation class for module translation. Holds a reference to the module
 /// being translated, and the mappings between the original and the translated
 /// functions, basic blocks and values. It is practically easier to hold these
@@ -209,7 +312,7 @@ class ModuleTranslation {
   /// PHI nodes are constructed for block arguments but are _not_ connected to
   /// the predecessors that may not exist yet.
   LogicalResult convertBlock(Block &bb, bool ignoreArguments,
-                             llvm::IRBuilderBase &builder);
+                             CapturingIRBuilder &builder);
 
   /// Gets the named metadata in the LLVM IR module being constructed, creating
   /// it if it does not exist.
@@ -299,7 +402,7 @@ class ModuleTranslation {
   ~ModuleTranslation();
 
   /// Converts individual components.
-  LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder);
+  LogicalResult convertOperation(Operation &op, CapturingIRBuilder &builder);
   LogicalResult convertFunctionSignatures();
   LogicalResult convertFunctions();
   LogicalResult convertComdats();
@@ -315,7 +418,9 @@ class ModuleTranslation {
   LogicalResult createTBAAMetadata();
 
   /// Translates dialect attributes attached to the given operation.
-  LogicalResult convertDialectAttributes(Operation *op);
+  LogicalResult
+  convertDialectAttributes(Operation *op,
+                           ArrayRef<llvm::Instruction *> instructions);
 
   /// Translates parameter attributes and adds them to the returned AttrBuilder.
   llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
index 044462d33cfd1..018246b362213 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
@@ -32,7 +32,7 @@ class AMXDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/AMX/AMXConversions.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp
index 7098592d506e0..17c16cd00994e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp
@@ -33,7 +33,7 @@ class ArmNeonDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/ArmNeon/ArmNeonConversions.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
index e6ee41188d594..fd141b52e117c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
@@ -32,7 +32,7 @@ class ArmSMEDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicConversions.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
index cd10811b68f02..325d0acc47616 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
@@ -32,7 +32,7 @@ class ArmSVEDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEConversions.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.cpp
index 51c304cfbb8e5..04894a52a889a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.cpp
@@ -25,7 +25,7 @@ class BuiltinDialectLLVMIRTranslationInterface
   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
 
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const override {
     return success(isa<ModuleOp>(op));
   }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
index eecc8f1001ca4..22fac6b93d4ef 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -41,7 +42,7 @@ class GPUDialectLLVMIRTranslationInterface
   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
 
   LogicalResult
-  convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *operation, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const override {
     return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
         .Case([&](gpu::GPUModuleOp) { return success(); })
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index f144c7158d679..acf43eb4e2ce9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -410,7 +410,7 @@ class LLVMDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     return convertOperationImpl(*op, builder, moduleTranslation);
   }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 88e3a45534075..c61e93ef723d7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -132,7 +132,7 @@ class NVVMDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
@@ -142,7 +142,8 @@ class NVVMDialectLLVMIRTranslationInterface
 
   /// Attaches module-level metadata for functions marked as kernels.
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
     auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
     if (!func)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index b964d1c082b20..56fd29ccd2237 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -261,7 +261,7 @@ processDataOperands(llvm::IRBuilderBase &builder,
 
 /// Converts an OpenACC data operation into LLVM IR.
 static LogicalResult convertDataOp(acc::DataOp &op,
-                                   llvm::IRBuilderBase &builder,
+                                   LLVM::CapturingIRBuilder &builder,
                                    LLVM::ModuleTranslation &moduleTranslation) {
   llvm::LLVMContext &ctx = builder.getContext();
   auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
@@ -484,7 +484,7 @@ class OpenACCDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final;
 };
 
@@ -493,7 +493,7 @@ class OpenACCDialectLLVMIRTranslationInterface
 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
 /// (including OpenACC runtime calls).
 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
-    Operation *op, llvm::IRBuilderBase &builder,
+    Operation *op, LLVM::CapturingIRBuilder &builder,
     LLVM::ModuleTranslation &moduleTranslation) const {
 
   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a..3b535f2ac21c8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -89,7 +89,7 @@ class OpenMPVarMappingStackFrame
 /// Find the insertion point for allocas given the current insertion point for
 /// normal operations in the builder.
 static llvm::OpenMPIRBuilder::InsertPointTy
-findAllocaInsertPoint(llvm::IRBuilderBase &builder,
+findAllocaInsertPoint(LLVM::CapturingIRBuilder &builder,
                       const LLVM::ModuleTranslation &moduleTranslation) {
   // If there is an alloca insertion point on stack, i.e. we are in a nested
   // operation and a specific point was provided by some surrounding operation,
@@ -133,7 +133,7 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
 /// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
 /// of the continuation block if provided.
 static llvm::BasicBlock *convertOmpOpRegions(
-    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
+    Region &region, StringRef blockName, LLVM::CapturingIRBuilder &builder,
     LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus,
     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
   llvm::BasicBlock *continuationBlock =
@@ -261,7 +261,7 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
 
 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpMaster(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
@@ -288,7 +288,7 @@ convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
 
 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpCritical(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   auto criticalOp = cast<omp::CriticalOp>(opInst);
@@ -400,7 +400,7 @@ collectReductionDecls(T loop,
 /// terminator. If set, `continuationBlockArgs` is populated with translated
 /// values that correspond to the values omp.yield'ed from the region.
 static LogicalResult inlineConvertOmpRegions(
-    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
+    Region &region, StringRef blockName, LLVM::CapturingIRBuilder &builder,
     LLVM::ModuleTranslation &moduleTranslation,
     SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
   if (region.empty())
@@ -455,7 +455,8 @@ using OwningAtomicReductionGen =
 /// reduction declaration. The generator uses `builder` but ignores its
 /// insertion point.
 static OwningReductionGen
-makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
+makeReductionGen(omp::ReductionDeclareOp decl,
+                 LLVM::CapturingIRBuilder &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
   // The lambda is mutable because we need access to non-const methods of decl
   // (which aren't actually mutating it), and we must capture decl by-value to
@@ -486,7 +487,7 @@ makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
 /// reduction declaration.
 static OwningAtomicReductionGen
 makeAtomicReductionGen(omp::ReductionDeclareOp decl,
-                       llvm::IRBuilderBase &builder,
+                       LLVM::CapturingIRBuilder &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
   if (decl.getAtomicReductionRegion().empty())
     return OwningAtomicReductionGen();
@@ -514,7 +515,7 @@ makeAtomicReductionGen(omp::ReductionDeclareOp decl,
 
 /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpOrdered(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
   auto orderedOp = cast<omp::OrderedOp>(opInst);
 
@@ -544,7 +545,7 @@ convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
 /// Converts an OpenMP 'ordered_region' operation into LLVM IR using
 /// OpenMPIRBuilder.
 static LogicalResult
-convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpOrderedRegion(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
@@ -577,7 +578,7 @@ convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
 }
 
 static LogicalResult
-convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpSections(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   using StorableBodyGenCallbackTy =
@@ -645,7 +646,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
 
 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
+convertOmpSingle(omp::SingleOp &singleOp, LLVM::CapturingIRBuilder &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
@@ -663,7 +664,7 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
 
 // Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
 static LogicalResult
-convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
+convertOmpTeams(omp::TeamsOp op, LLVM::CapturingIRBuilder &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
@@ -702,7 +703,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
 
 /// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
+convertOmpTaskOp(omp::TaskOp taskOp, LLVM::CapturingIRBuilder &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
@@ -758,7 +759,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
 
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, llvm::IRBuilderBase &builder,
+convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, LLVM::CapturingIRBuilder &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
@@ -780,7 +781,7 @@ convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, llvm::IRBuilderBase &builder,
 /// Allocate space for privatized reduction variables.
 template <typename T>
 static void
-allocReductionVars(T loop, llvm::IRBuilderBase &builder,
+allocReductionVars(T loop, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    SmallVector<omp::ReductionDeclareOp> &reductionDecls,
@@ -803,7 +804,7 @@ allocReductionVars(T loop, llvm::IRBuilderBase &builder,
 /// Collect reduction info
 template <typename T>
 static void collectReductionInfo(
-    T loop, llvm::IRBuilderBase &builder,
+    T loop, LLVM::CapturingIRBuilder &builder,
     LLVM::ModuleTranslation &moduleTranslation,
     SmallVector<omp::ReductionDeclareOp> &reductionDecls,
     SmallVector<OwningReductionGen> &owningReductionGens,
@@ -835,7 +836,7 @@ static void collectReductionInfo(
 
 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpWsLoop(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
   auto loop = cast<omp::WsLoopOp>(opInst);
   // TODO: this should be in the op verifier instead.
@@ -1002,7 +1003,7 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
 
 /// Converts the OpenMP parallel operation to LLVM IR.
 static LogicalResult
-convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
+convertOmpParallel(omp::ParallelOp opInst, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
@@ -1124,7 +1125,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
 
 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
-convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpSimdLoop(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
   auto loop = cast<omp::SimdLoopOp>(opInst);
 
@@ -1232,7 +1233,7 @@ convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
 
 /// Convert omp.atomic.read operation to LLVM IR.
 static LogicalResult
-convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpAtomicRead(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
 
   auto readOp = cast<omp::AtomicReadOp>(opInst);
@@ -1255,7 +1256,7 @@ convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
 
 /// Converts an omp.atomic.write operation to LLVM IR.
 static LogicalResult
-convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpAtomicWrite(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
   auto writeOp = cast<omp::AtomicWriteOp>(opInst);
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
@@ -1290,7 +1291,7 @@ llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
 /// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
 static LogicalResult
 convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
-                       llvm::IRBuilderBase &builder,
+                       LLVM::CapturingIRBuilder &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
@@ -1333,11 +1334,12 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
   LogicalResult updateGenStatus = success();
   auto updateFn = [&opInst, &moduleTranslation, &updateGenStatus](
                       llvm::Value *atomicx,
-                      llvm::IRBuilder<> &builder) -> llvm::Value * {
+                      llvm::IRBuilderBase &builder) -> llvm::Value * {
     Block &bb = *opInst.getRegion().begin();
     moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
     moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
-    if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
+    if (failed(moduleTranslation.convertBlock(
+            bb, true, static_cast<LLVM::CapturingIRBuilder &>(builder)))) {
       updateGenStatus = (opInst.emitError()
                          << "unable to convert update operation to llvm IR");
       return nullptr;
@@ -1360,7 +1362,7 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
 
 static LogicalResult
 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
-                        llvm::IRBuilderBase &builder,
+                        LLVM::CapturingIRBuilder &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   mlir::Value mlirExpr;
@@ -1423,14 +1425,15 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
 
   LogicalResult updateGenStatus = success();
   auto updateFn = [&](llvm::Value *atomicx,
-                      llvm::IRBuilder<> &builder) -> llvm::Value * {
+                      llvm::IRBuilderBase &builder) -> llvm::Value * {
     if (atomicWriteOp)
       return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
     Block &bb = *atomicUpdateOp.getRegion().begin();
     moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                                atomicx);
     moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
-    if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
+    if (failed(moduleTranslation.convertBlock(
+            bb, true, static_cast<LLVM::CapturingIRBuilder &>(builder)))) {
       updateGenStatus = (atomicUpdateOp.emitError()
                          << "unable to convert update operation to llvm IR");
       return nullptr;
@@ -1457,7 +1460,7 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
 /// reduction within WsLoopOp and ParallelOp, but can be easily extended.
 static LogicalResult
 convertOmpReductionOp(omp::ReductionOp reductionOp,
-                      llvm::IRBuilderBase &builder,
+                      LLVM::CapturingIRBuilder &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
   // Find the declaration that corresponds to the reduction op.
   omp::ReductionDeclareOp declaration;
@@ -1510,7 +1513,7 @@ convertOmpReductionOp(omp::ReductionOp reductionOp,
 /// Converts an OpenMP Threadprivate operation into LLVM IR using
 /// OpenMPIRBuilder.
 static LogicalResult
-convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpThreadprivate(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
   auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
@@ -1676,7 +1679,8 @@ uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
 // This function is somewhat equivalent to Clang's getExprTypeSize inside of
 // CGOpenMPRuntime.cpp.
 llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
-                            Operation *clauseOp, llvm::IRBuilderBase &builder,
+                            Operation *clauseOp,
+                            LLVM::CapturingIRBuilder &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
   // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
   // the size in inconsistent byte or bit format.
@@ -1727,7 +1731,7 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
                                    llvm::SmallVectorImpl<Value> &mapOperands,
                                    LLVM::ModuleTranslation &moduleTranslation,
                                    DataLayout &dl,
-                                   llvm::IRBuilderBase &builder) {
+                                   LLVM::CapturingIRBuilder &builder) {
   for (mlir::Value mapValue : mapOperands) {
     assert(mlir::isa<mlir::omp::MapInfoOp>(mapValue.getDefiningOp()) &&
            "missing map info operation or incorrect map info operation type");
@@ -1763,7 +1767,7 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
 }
 
 // Generate all map related information and fill the combinedInfo.
-static void genMapInfos(llvm::IRBuilderBase &builder,
+static void genMapInfos(LLVM::CapturingIRBuilder &builder,
                         LLVM::ModuleTranslation &moduleTranslation,
                         DataLayout &dl,
                         llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
@@ -1853,7 +1857,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
 }
 
 static LogicalResult
-convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
+convertOmpTargetData(Operation *op, LLVM::CapturingIRBuilder &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *ifCond = nullptr;
   int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
@@ -2092,7 +2096,7 @@ static bool targetOpSupported(Operation &opInst) {
 static void
 handleDeclareTargetMapVar(MapInfoData &mapData,
                           LLVM::ModuleTranslation &moduleTranslation,
-                          llvm::IRBuilderBase &builder) {
+                          LLVM::CapturingIRBuilder &builder) {
   for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
     // In the case of declare target mapped variables, the basePointer is
     // the reference pointer generated by the convertDeclareTargetAttr
@@ -2168,7 +2172,7 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
 static llvm::IRBuilderBase::InsertPoint
 createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                              llvm::Value *input, llvm::Value *&retVal,
-                             llvm::IRBuilderBase &builder,
+                             LLVM::CapturingIRBuilder &builder,
                              llvm::OpenMPIRBuilder &ompBuilder,
                              LLVM::ModuleTranslation &moduleTranslation,
                              llvm::IRBuilderBase::InsertPoint allocaIP,
@@ -2232,7 +2236,7 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 static void
 createAlteredByCaptureMap(MapInfoData &mapData,
                           LLVM::ModuleTranslation &moduleTranslation,
-                          llvm::IRBuilderBase &builder) {
+                          LLVM::CapturingIRBuilder &builder) {
   for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
     // if it's declare target, skip it, it's handled seperately.
     if (!mapData.IsDeclareTarget[i]) {
@@ -2306,7 +2310,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
 }
 
 static LogicalResult
-convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
+convertOmpTarget(Operation &opInst, LLVM::CapturingIRBuilder &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
 
   if (!targetOpSupported(opInst))
@@ -2549,20 +2553,22 @@ class OpenMPDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final;
 
   /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
   /// calls, or operation amendments
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final;
 };
 
 } // namespace
 
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
-    Operation *op, NamedAttribute attribute,
+    Operation *op, ArrayRef<llvm::Instruction *> instructions,
+    NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
   return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
              attribute.getName())
@@ -2652,7 +2658,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
 /// (including OpenMP runtime calls).
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
-    Operation *op, llvm::IRBuilderBase &builder,
+    Operation *op, LLVM::CapturingIRBuilder &builder,
     LLVM::ModuleTranslation &moduleTranslation) const {
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 5ab70280f6c81..b6c54a44f06f2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -71,7 +71,7 @@ class ROCDLDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/LLVMIR/ROCDLConversions.inc"
@@ -81,7 +81,8 @@ class ROCDLDialectLLVMIRTranslationInterface
 
   /// Attaches module-level metadata for functions marked as kernels.
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
     if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
index fa5f61420ee8a..6a9fc79f8da44 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
@@ -33,7 +33,7 @@ class X86VectorDialectLLVMIRTranslationInterface
   /// Translates the given operation to LLVM IR using the provided IR builder
   /// and saving the state in `moduleTranslation`.
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
 #include "mlir/Dialect/X86Vector/X86VectorConversions.inc"
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d6afe354178d6..f46cb5a3dc1a4 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -52,6 +52,7 @@
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include <optional>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -631,9 +632,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
 
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
-LogicalResult
-ModuleTranslation::convertOperation(Operation &op,
-                                    llvm::IRBuilderBase &builder) {
+LogicalResult ModuleTranslation::convertOperation(Operation &op,
+                                                  CapturingIRBuilder &builder) {
   const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
   if (!opIface)
     return op.emitError("cannot be converted to LLVM IR: missing "
@@ -641,11 +641,13 @@ ModuleTranslation::convertOperation(Operation &op,
                         "dialect for op: ")
            << op.getName();
 
+  detail::InstructionCapturingInserter::CollectionScope scope(builder);
   if (failed(opIface->convertOperation(&op, builder, *this)))
     return op.emitError("LLVM Translation failed for operation: ")
            << op.getName();
 
-  return convertDialectAttributes(&op);
+  return convertDialectAttributes(
+      &op, builder.getInserter().getCapturedInstructions());
 }
 
 /// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
@@ -656,7 +658,7 @@ ModuleTranslation::convertOperation(Operation &op,
 /// instructions at the end of the block and leaves `builder` in a state
 /// suitable for further insertion into the end of the block.
 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
-                                              llvm::IRBuilderBase &builder) {
+                                              CapturingIRBuilder &builder) {
   builder.SetInsertPoint(lookupBlock(&bb));
   auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
 
@@ -808,7 +810,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
   // global or itself. So all global variables need to be mapped first.
   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
     if (Block *initializer = op.getInitializerBlock()) {
-      llvm::IRBuilder<> builder(llvmModule->getContext());
+      CapturingIRBuilder builder(llvmModule->getContext());
       for (auto &op : initializer->without_terminator()) {
         if (failed(convertOperation(op, builder)) ||
             !isa<llvm::Constant>(lookupValue(op.getResult(0))))
@@ -843,7 +845,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
   }
 
   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
-    if (failed(convertDialectAttributes(op)))
+    if (failed(convertDialectAttributes(op, {})))
       return failure();
 
   // Finally, update the compile units their respective sets of global variables
@@ -994,7 +996,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   // converted before uses.
   auto blocks = getTopologicallySortedBlocks(func.getBody());
   for (Block *bb : blocks) {
-    llvm::IRBuilder<> builder(llvmContext);
+    CapturingIRBuilder builder(llvmContext);
     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
       return failure();
   }
@@ -1004,12 +1006,13 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   detail::connectPHINodes(func.getBody(), *this);
 
   // Finally, convert dialect attributes attached to the function.
-  return convertDialectAttributes(func);
+  return convertDialectAttributes(func, {});
 }
 
-LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
+LogicalResult ModuleTranslation::convertDialectAttributes(
+    Operation *op, ArrayRef<llvm::Instruction *> instructions) {
   for (NamedAttribute attribute : op->getDialectAttrs())
-    if (failed(iface.amendOperation(op, attribute, *this)))
+    if (failed(iface.amendOperation(op, instructions, attribute, *this)))
       return failure();
   return success();
 }
@@ -1131,7 +1134,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
     // Do not convert external functions, but do process dialect attributes
     // attached to them.
     if (function.isExternal()) {
-      if (failed(convertDialectAttributes(function)))
+      if (failed(convertDialectAttributes(function, {})))
         return failure();
       continue;
     }
@@ -1434,7 +1437,7 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
   LLVM::ensureDistinctSuccessors(module);
 
   ModuleTranslation translator(module, std::move(llvmModule));
-  llvm::IRBuilder<> llvmBuilder(llvmContext);
+  CapturingIRBuilder llvmBuilder(llvmContext);
 
   // Convert module before functions and operations inside, so dialect
   // attributes can be used to change dialect-specific global configurations via
diff --git a/mlir/test/Target/LLVMIR/test.mlir b/mlir/test/Target/LLVMIR/test.mlir
index f48738f44f44b..0ab1b7267d959 100644
--- a/mlir/test/Target/LLVMIR/test.mlir
+++ b/mlir/test/Target/LLVMIR/test.mlir
@@ -16,3 +16,27 @@ module {
 module attributes {test.discardable_mod_attr = true} {}
 
 // CHECK: @sym_from_attr = external global i32
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation
+llvm.func @dialect_attr_translation() {
+  // CHECK: ret void, !annotation ![[MD_ID:.+]]
+  llvm.return {test.add_annotation}
+}
+// CHECK: ![[MD_ID]] = !{!"annotation_from_test"}
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation_multi
+llvm.func @dialect_attr_translation_multi(%a: i64, %b: i64, %c: i64) -> i64 {
+  // CHECK: add {{.*}}, !annotation ![[MD_ID_ADD:.+]]
+  // CHECK: mul {{.*}}, !annotation ![[MD_ID_MUL:.+]]
+  // CHECK: ret {{.*}}, !annotation ![[MD_ID_RET:.+]]
+  %ab = llvm.add %a, %b {test.add_annotation = "add"} : i64
+  %r = llvm.mul %ab, %c {test.add_annotation = "mul"} : i64
+  llvm.return {test.add_annotation = "ret"} %r : i64
+}
+// CHECK-DAG: ![[MD_ID_ADD]] = !{!"annotation_from_test: add"}
+// CHECK-DAG: ![[MD_ID_MUL]] = !{!"annotation_from_test: mul"}
+// CHECK-DAG: ![[MD_ID_RET]] = !{!"annotation_from_test: ret"}
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index 7110d999c8f8a..ceb8264118d3c 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -32,18 +32,20 @@ class TestDialectLLVMIRTranslationInterface
   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
 
   LogicalResult
-  amendOperation(Operation *op, NamedAttribute attribute,
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final;
 
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *op, LLVM::CapturingIRBuilder &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final;
 };
 
 } // namespace
 
 LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
-    Operation *op, NamedAttribute attribute,
+    Operation *op, ArrayRef<llvm::Instruction *> instructions,
+    NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
   return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
              attribute.getName())
@@ -70,6 +72,18 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
                     /*sym_visibility=*/nullptr);
               }
 
+              return success();
+            })
+      .Case("test.add_annotation",
+            [&](Attribute attr) {
+              for (llvm::Instruction *instruction : instructions) {
+                if (auto strAttr = dyn_cast<StringAttr>(attr)) {
+                  instruction->addAnnotationMetadata("annotation_from_test: " +
+                                                     strAttr.getValue().str());
+                } else {
+                  instruction->addAnnotationMetadata("annotation_from_test");
+                }
+              }
               return success();
             })
       .Default([](Attribute) {
@@ -79,7 +93,7 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
 }
 
 LogicalResult TestDialectLLVMIRTranslationInterface::convertOperation(
-    Operation *op, llvm::IRBuilderBase &builder,
+    Operation *op, LLVM::CapturingIRBuilder &builder,
     LLVM::ModuleTranslation &moduleTranslation) const {
   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
       // `test.symbol`s are translated into global integers in LLVM IR, with a



More information about the Mlir-commits mailing list