[Mlir-commits] [mlir] 9519e3e - [mlir] support dialect attribute translation to LLVM IR (#75309)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 19 05:18:20 PST 2023
Author: Oleksandr "Alex" Zinenko
Date: 2023-12-19T14:18:16+01:00
New Revision: 9519e3ecbf6ed251c5ab7c74549fe86df1efc14c
URL: https://github.com/llvm/llvm-project/commit/9519e3ecbf6ed251c5ab7c74549fe86df1efc14c
DIFF: https://github.com/llvm/llvm-project/commit/9519e3ecbf6ed251c5ab7c74549fe86df1efc14c.diff
LOG: [mlir] support dialect attribute translation to LLVM IR (#75309)
Extend the `amendOperation` mechanism for translating dialect attributes
attached to operations from another dialect when translating MLIR to
LLVM IR. Previously, this mechanism would have no knowledge of the LLVM
IR instructions created for the given operation, making it impossible
for it to perform local modifications such as attaching operation-level
metadata. Collect instructions inserted by the LLVM IR builder and pass
them to `amendOperation`.
Added:
Modified:
mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/LLVMIR/test.mlir
mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 0531c0ec953fe2..19991a6f89d80f 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -18,6 +18,7 @@
#include "mlir/Support/LogicalResult.h"
namespace llvm {
+class Instruction;
class IRBuilderBase;
} // namespace llvm
@@ -52,7 +53,8 @@ class LLVMTranslationDialectInterface
/// translation results and amend the corresponding IR constructs. Does
/// nothing and succeeds by default.
virtual LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
return success();
}
@@ -78,11 +80,13 @@ class LLVMTranslationInterface
/// Acts on the given operation using the interface implemented by the dialect
/// of one of the operation's dialect attributes.
virtual LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
if (const LLVMTranslationDialectInterface *iface =
getInterfaceFor(attribute.getNameDialect())) {
- return iface->amendOperation(op, attribute, moduleTranslation);
+ return iface->amendOperation(op, instructions, attribute,
+ moduleTranslation);
}
return success();
}
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 3f797f41f10ab0..d6b03aca28d24d 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -209,7 +209,10 @@ class ModuleTranslation {
/// PHI nodes are constructed for block arguments but are _not_ connected to
/// the predecessors that may not exist yet.
LogicalResult convertBlock(Block &bb, bool ignoreArguments,
- llvm::IRBuilderBase &builder);
+ llvm::IRBuilderBase &builder) {
+ return convertBlockImpl(bb, ignoreArguments, builder,
+ /*recordInsertions=*/false);
+ }
/// Gets the named metadata in the LLVM IR module being constructed, creating
/// it if it does not exist.
@@ -299,12 +302,16 @@ class ModuleTranslation {
~ModuleTranslation();
/// Converts individual components.
- LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder);
+ LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder,
+ bool recordInsertions = false);
LogicalResult convertFunctionSignatures();
LogicalResult convertFunctions();
LogicalResult convertComdats();
LogicalResult convertGlobals();
LogicalResult convertOneFunction(LLVMFuncOp func);
+ LogicalResult convertBlockImpl(Block &bb, bool ignoreArguments,
+ llvm::IRBuilderBase &builder,
+ bool recordInsertions);
/// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
/// TBAATagAttr.
@@ -315,7 +322,9 @@ class ModuleTranslation {
LogicalResult createTBAAMetadata();
/// Translates dialect attributes attached to the given operation.
- LogicalResult convertDialectAttributes(Operation *op);
+ LogicalResult
+ convertDialectAttributes(Operation *op,
+ ArrayRef<llvm::Instruction *> instructions);
/// Translates parameter attributes and adds them to the returned AttrBuilder.
llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 88e3a455340750..0d6bca5e2203ea 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -142,7 +142,8 @@ class NVVMDialectLLVMIRTranslationInterface
/// Attaches module-level metadata for functions marked as kernels.
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 088e7ae4231bef..629584683f4991 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2572,14 +2572,16 @@ class OpenMPDialectLLVMIRTranslationInterface
/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
/// calls, or operation amendments
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final;
};
} // namespace
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
- Operation *op, NamedAttribute attribute,
+ Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
attribute.getName())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 5ab70280f6c818..55a6285ec87eb4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -81,7 +81,8 @@ class ROCDLDialectLLVMIRTranslationInterface
/// Attaches module-level metadata for functions marked as kernels.
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 9f0e1f3c3bb6f6..1722d74c08b628 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -59,6 +59,113 @@ using namespace mlir::LLVM::detail;
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
+namespace {
+/// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
+/// instructions that are created for future reference.
+///
+/// This is intended to be used with the `CollectionScope` RAII object:
+///
+/// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
+/// {
+/// InstructionCapturingInserter::CollectionScope scope(
+/// builder, /*isBuilderCapturing=*/true);
+/// // Call IRBuilder methods as usual.
+///
+/// // This will return a list of all instructions created by the builder,
+/// // in order of creation.
+/// builder.getInserter().getCapturedInstructions();
+/// }
+/// // This will return an empty list.
+/// builder.getInserter().getCapturedInstructions();
+///
+/// The capturing functionality is _disabled_ by default for performance
+/// reasons. It must be explicitly enabled, which is done by creating a
+/// `CollectionScope`.
+class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
+public:
+ /// Constructs the inserter.
+ InstructionCapturingInserter()
+ : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
+ if (LLVM_LIKELY(enabled))
+ capturedInstructions.push_back(instruction);
+ }) {}
+
+ /// Returns the list of LLVM IR instructions captured since the last cleanup.
+ ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
+ return capturedInstructions;
+ }
+
+ /// Clears the list of captured LLVM IR instructions.
+ void clearCapturedInstructions() { capturedInstructions.clear(); }
+
+ /// RAII object enabling the capture of created LLVM IR instructions.
+ class CollectionScope {
+ public:
+ /// Creates the scope for the inserter of the given builder. Does nothing
+ /// unless `isBuilderCapturing` is set.
+ CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
+
+ /// Ends the scope.
+ ~CollectionScope();
+
+ ArrayRef<llvm::Instruction *> getCapturedInstructions() {
+ if (!inserter)
+ return {};
+ return inserter->getCapturedInstructions();
+ }
+
+ private:
+ /// Back reference to the inserter.
+ InstructionCapturingInserter *inserter = nullptr;
+
+ /// List of instructions in the inserter prior to this scope.
+ SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
+
+ /// Whether the inserter was enabled prior to this scope.
+ bool wasEnabled;
+ };
+
+ /// Enable or disable the capturing mechanism.
+ void setEnabled(bool enabled = true) { this->enabled = enabled; }
+
+private:
+ /// List of captured instructions.
+ SmallVector<llvm::Instruction *> capturedInstructions;
+
+ /// Whether the collection is enabled.
+ bool enabled = false;
+};
+
+using CapturingIRBuilder =
+ llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>;
+} // namespace
+
+InstructionCapturingInserter::CollectionScope::CollectionScope(
+ llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
+
+ if (!isBuilderCapturing)
+ return;
+
+ auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
+ inserter = &capturingIRBuilder.getInserter();
+ wasEnabled = inserter->enabled;
+ if (wasEnabled)
+ previouslyCollectedInstructions.swap(inserter->capturedInstructions);
+ inserter->setEnabled(true);
+}
+
+InstructionCapturingInserter::CollectionScope::~CollectionScope() {
+ if (!inserter)
+ return;
+
+ previouslyCollectedInstructions.swap(inserter->capturedInstructions);
+ // If collection was enabled (likely in another, surrounding scope), keep
+ // the instructions collected in this scope.
+ if (wasEnabled) {
+ llvm::append_range(inserter->capturedInstructions,
+ previouslyCollectedInstructions);
+ }
+ inserter->setEnabled(wasEnabled);
+}
+
/// Translates the given data layout spec attribute to the LLVM IR data layout.
/// Only integer, float, pointer and endianness entries are currently supported.
static FailureOr<llvm::DataLayout>
@@ -631,9 +738,9 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
-LogicalResult
-ModuleTranslation::convertOperation(Operation &op,
- llvm::IRBuilderBase &builder) {
+LogicalResult ModuleTranslation::convertOperation(Operation &op,
+ llvm::IRBuilderBase &builder,
+ bool recordInsertions) {
const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
if (!opIface)
return op.emitError("cannot be converted to LLVM IR: missing "
@@ -641,11 +748,13 @@ ModuleTranslation::convertOperation(Operation &op,
"dialect for op: ")
<< op.getName();
+ InstructionCapturingInserter::CollectionScope scope(builder,
+ recordInsertions);
if (failed(opIface->convertOperation(&op, builder, *this)))
return op.emitError("LLVM Translation failed for operation: ")
<< op.getName();
- return convertDialectAttributes(&op);
+ return convertDialectAttributes(&op, scope.getCapturedInstructions());
}
/// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
@@ -655,8 +764,10 @@ ModuleTranslation::convertOperation(Operation &op,
/// been created for `bb` and included in the block mapping. Inserts new
/// instructions at the end of the block and leaves `builder` in a state
/// suitable for further insertion into the end of the block.
-LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
- llvm::IRBuilderBase &builder) {
+LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
+ bool ignoreArguments,
+ llvm::IRBuilderBase &builder,
+ bool recordInsertions) {
builder.SetInsertPoint(lookupBlock(&bb));
auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
@@ -687,7 +798,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
builder.SetCurrentDebugLocation(
debugTranslation->translateLoc(op.getLoc(), subprogram));
- if (failed(convertOperation(op, builder)))
+ if (failed(convertOperation(op, builder, recordInsertions)))
return failure();
// Set the branch weight metadata on the translated instruction.
@@ -844,7 +955,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
}
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
- if (failed(convertDialectAttributes(op)))
+ if (failed(convertDialectAttributes(op, {})))
return failure();
// Finally, update the compile units their respective sets of global variables
@@ -997,8 +1108,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// converted before uses.
auto blocks = getTopologicallySortedBlocks(func.getBody());
for (Block *bb : blocks) {
- llvm::IRBuilder<> builder(llvmContext);
- if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
+ CapturingIRBuilder builder(llvmContext);
+ if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
+ /*recordInsertions=*/true)))
return failure();
}
@@ -1007,12 +1119,13 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
detail::connectPHINodes(func.getBody(), *this);
// Finally, convert dialect attributes attached to the function.
- return convertDialectAttributes(func);
+ return convertDialectAttributes(func, {});
}
-LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
+LogicalResult ModuleTranslation::convertDialectAttributes(
+ Operation *op, ArrayRef<llvm::Instruction *> instructions) {
for (NamedAttribute attribute : op->getDialectAttrs())
- if (failed(iface.amendOperation(op, attribute, *this)))
+ if (failed(iface.amendOperation(op, instructions, attribute, *this)))
return failure();
return success();
}
@@ -1134,7 +1247,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
// Do not convert external functions, but do process dialect attributes
// attached to them.
if (function.isExternal()) {
- if (failed(convertDialectAttributes(function)))
+ if (failed(convertDialectAttributes(function, {})))
return failure();
continue;
}
diff --git a/mlir/test/Target/LLVMIR/test.mlir b/mlir/test/Target/LLVMIR/test.mlir
index f48738f44f44b4..0ab1b7267d9598 100644
--- a/mlir/test/Target/LLVMIR/test.mlir
+++ b/mlir/test/Target/LLVMIR/test.mlir
@@ -16,3 +16,27 @@ module {
module attributes {test.discardable_mod_attr = true} {}
// CHECK: @sym_from_attr = external global i32
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation
+llvm.func @dialect_attr_translation() {
+ // CHECK: ret void, !annotation ![[MD_ID:.+]]
+ llvm.return {test.add_annotation}
+}
+// CHECK: ![[MD_ID]] = !{!"annotation_from_test"}
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation_multi
+llvm.func @dialect_attr_translation_multi(%a: i64, %b: i64, %c: i64) -> i64 {
+ // CHECK: add {{.*}}, !annotation ![[MD_ID_ADD:.+]]
+ // CHECK: mul {{.*}}, !annotation ![[MD_ID_MUL:.+]]
+ // CHECK: ret {{.*}}, !annotation ![[MD_ID_RET:.+]]
+ %ab = llvm.add %a, %b {test.add_annotation = "add"} : i64
+ %r = llvm.mul %ab, %c {test.add_annotation = "mul"} : i64
+ llvm.return {test.add_annotation = "ret"} %r : i64
+}
+// CHECK-DAG: ![[MD_ID_ADD]] = !{!"annotation_from_test: add"}
+// CHECK-DAG: ![[MD_ID_MUL]] = !{!"annotation_from_test: mul"}
+// CHECK-DAG: ![[MD_ID_RET]] = !{!"annotation_from_test: ret"}
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index 7110d999c8f8ae..2dd99c67c1439b 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -32,7 +32,8 @@ class TestDialectLLVMIRTranslationInterface
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final;
LogicalResult
@@ -43,7 +44,8 @@ class TestDialectLLVMIRTranslationInterface
} // namespace
LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
- Operation *op, NamedAttribute attribute,
+ Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
attribute.getName())
@@ -70,6 +72,18 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
/*sym_visibility=*/nullptr);
}
+ return success();
+ })
+ .Case("test.add_annotation",
+ [&](Attribute attr) {
+ for (llvm::Instruction *instruction : instructions) {
+ if (auto strAttr = dyn_cast<StringAttr>(attr)) {
+ instruction->addAnnotationMetadata("annotation_from_test: " +
+ strAttr.getValue().str());
+ } else {
+ instruction->addAnnotationMetadata("annotation_from_test");
+ }
+ }
return success();
})
.Default([](Attribute) {
More information about the Mlir-commits
mailing list