[Mlir-commits] [mlir] 87a0479 - [mlir][llvm] Fuse access_group & loop export (NFC)
Christian Ulmann
llvmlistbot at llvm.org
Fri Feb 17 06:33:39 PST 2023
Author: Christian Ulmann
Date: 2023-02-17T15:31:21+01:00
New Revision: 87a0479538fe4fad1cbbf729ee6e1ee35326f093
URL: https://github.com/llvm/llvm-project/commit/87a0479538fe4fad1cbbf729ee6e1ee35326f093
DIFF: https://github.com/llvm/llvm-project/commit/87a0479538fe4fad1cbbf729ee6e1ee35326f093.diff
LOG: [mlir][llvm] Fuse access_group & loop export (NFC)
This commit moves the access group translation into the
LoopAnnotationTranslation class as these two metadata kinds only appear
together.
Drops the access group cleanup from `ModuleTranslation::forgetMapping`,
as that function is only called on function regions. Access groups only
appear in the region of a global metadata operation, so that cleanup never
removed any entries.
Analogous to https://reviews.llvm.org/D143577
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D144253
Added:
Modified:
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 2b08d96b680b9..faca8fc5e4fb4 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -120,11 +120,6 @@ class ModuleTranslation {
/// in these blocks.
void forgetMapping(Region &region);
- /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
- /// LLVM dialect access group operation.
- llvm::MDNode *getAccessGroup(Operation *op,
- SymbolRefAttr accessGroupRef) const;
-
/// Returns the LLVM metadata corresponding to a symbol reference to an mlir
/// LLVM dialect alias scope operation
llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const;
@@ -332,11 +327,6 @@ class ModuleTranslation {
/// values after all operations are converted.
DenseMap<Operation *, llvm::Instruction *> branchMapping;
- /// Mapping from an access group metadata operation to its LLVM metadata.
- /// This map is populated on module entry and is used to annotate loops (as
- /// identified via their branches) and contained memory accesses.
- DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
-
/// Mapping from an alias scope metadata operation to its LLVM metadata.
/// This map is populated on module entry.
DenseMap<Operation *, llvm::MDNode *> aliasScopeMetadataMapping;
diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
index 5f27f97f2fd45..4a535dcc3a8d3 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
@@ -15,12 +15,11 @@ using namespace mlir::LLVM::detail;
namespace {
/// Helper class that keeps the state of one attribute to metadata conversion.
struct LoopAnnotationConversion {
- LoopAnnotationConversion(LoopAnnotationAttr attr,
- ModuleTranslation &moduleTranslation, Operation *op,
- LoopAnnotationTranslation &loopAnnotationTranslation)
- : attr(attr), moduleTranslation(moduleTranslation), op(op),
- loopAnnotationTranslation(loopAnnotationTranslation),
- ctx(moduleTranslation.getLLVMContext()) {}
+ LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
+ LoopAnnotationTranslation &loopAnnotationTranslation,
+ llvm::LLVMContext &ctx)
+ : attr(attr), op(op),
+ loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
/// Converts this struct's loop annotation into a corresponding LLVMIR
/// metadata representation.
@@ -46,7 +45,6 @@ struct LoopAnnotationConversion {
void convertLoopOptions(LoopUnswitchAttr options);
LoopAnnotationAttr attr;
- ModuleTranslation &moduleTranslation;
Operation *op;
LoopAnnotationTranslation &loopAnnotationTranslation;
llvm::LLVMContext &ctx;
@@ -95,7 +93,8 @@ void LoopAnnotationConversion::convertFollowupNode(StringRef name,
if (!attr)
return;
- llvm::MDNode *node = loopAnnotationTranslation.translate(attr, op);
+ llvm::MDNode *node =
+ loopAnnotationTranslation.translateLoopAnnotation(attr, op);
metadataNodes.push_back(
llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node}));
@@ -225,7 +224,7 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
for (SymbolRefAttr accessGroupRef : parallelAccessGroups)
parallelAccess.push_back(
- moduleTranslation.getAccessGroup(op, accessGroupRef));
+ loopAnnotationTranslation.getAccessGroup(op, accessGroupRef));
metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
}
@@ -236,7 +235,8 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
return loopMD;
}
-llvm::MDNode *LoopAnnotationTranslation::translate(LoopAnnotationAttr attr,
+llvm::MDNode *
+LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr,
Operation *op) {
if (!attr)
return nullptr;
@@ -246,9 +246,47 @@ llvm::MDNode *LoopAnnotationTranslation::translate(LoopAnnotationAttr attr,
return loopMD;
loopMD =
- LoopAnnotationConversion(attr, moduleTranslation, op, *this).convert();
+ LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext())
+ .convert();
// Store a map from this Attribute to the LLVM metadata in case we
// encounter it again.
mapLoopMetadata(attr, loopMD);
return loopMD;
}
+
+LogicalResult LoopAnnotationTranslation::createAccessGroupMetadata() {
+ mlirModule->walk([&](LLVM::MetadataOp metadatas) {
+ metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
+ llvm::MDNode *accessGroup =
+ llvm::MDNode::getDistinct(llvmModule.getContext(), {});
+ accessGroupMetadataMapping.insert({op, accessGroup});
+ });
+ });
+ return success();
+}
+
+llvm::MDNode *
+LoopAnnotationTranslation::getAccessGroup(Operation *op,
+ SymbolRefAttr accessGroupRef) const {
+ auto metadataName = accessGroupRef.getRootReference();
+ auto accessGroupName = accessGroupRef.getLeafReference();
+ auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+ op->getParentOp(), metadataName);
+ auto *accessGroupOp =
+ SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
+ return accessGroupMetadataMapping.lookup(accessGroupOp);
+}
+
+llvm::MDNode *
+LoopAnnotationTranslation::getAccessGroups(Operation *op,
+ ArrayAttr accessGroupRefs) const {
+ if (!accessGroupRefs || accessGroupRefs.empty())
+ return nullptr;
+
+ SmallVector<llvm::Metadata *> groupMDs;
+ for (SymbolRefAttr groupRef : accessGroupRefs.getAsRange<SymbolRefAttr>())
+ groupMDs.push_back(getAccessGroup(op, groupRef));
+ if (groupMDs.size() == 1)
+ return llvm::cast<llvm::MDNode>(groupMDs.front());
+ return llvm::MDNode::get(llvmModule.getContext(), groupMDs);
+}
diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
index 0bbd5442510fe..4de54f3998300 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
@@ -21,14 +21,28 @@ namespace mlir {
namespace LLVM {
namespace detail {
-/// A helper class that converts a LoopAnnotationAttr into a corresponding
-/// llvm::MDNode.
+/// A helper class that converts LoopAnnotationAttrs and AccessGroupMetadataOps
+/// into corresponding llvm::MDNodes.
class LoopAnnotationTranslation {
public:
- LoopAnnotationTranslation(LLVM::ModuleTranslation &moduleTranslation)
- : moduleTranslation(moduleTranslation) {}
+ LoopAnnotationTranslation(Operation *mlirModule, llvm::Module &llvmModule)
+ : mlirModule(mlirModule), llvmModule(llvmModule) {}
- llvm::MDNode *translate(LoopAnnotationAttr attr, Operation *op);
+ llvm::MDNode *translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op);
+
+/// Traverses the global metadata operations in the `mlirModule` and creates
+/// LLVM metadata nodes for the access group operations they contain.
+ LogicalResult createAccessGroupMetadata();
+
+ /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+ /// LLVM dialect access group operation.
+ llvm::MDNode *getAccessGroup(Operation *op,
+ SymbolRefAttr accessGroupRef) const;
+
+/// Returns the LLVM metadata corresponding to a list of symbol references to
+/// mlir LLVM dialect access group operations. Returns nullptr if
+/// `accessGroupRefs` is null or empty.
+ llvm::MDNode *getAccessGroups(Operation *op, ArrayAttr accessGroupRefs) const;
private:
/// Returns the LLVM metadata corresponding to a llvm loop metadata attribute.
@@ -47,7 +61,13 @@ class LoopAnnotationTranslation {
/// The metadata is attached to Latch block branches with this attribute.
DenseMap<Attribute, llvm::MDNode *> loopMetadataMapping;
- LLVM::ModuleTranslation &moduleTranslation;
+ /// Mapping from an access group metadata operation to its LLVM metadata.
+ /// This map is populated on module entry and is used to annotate loops (as
+ /// identified via their branches) and contained memory accesses.
+ DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
+
+ Operation *mlirModule;
+ llvm::Module &llvmModule;
};
} // namespace detail
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 232b3d8b22160..04eddde310cf1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -421,8 +421,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
: mlirModule(module), llvmModule(std::move(llvmModule)),
debugTranslation(
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
- loopAnnotationTranslation(
- std::make_unique<LoopAnnotationTranslation>(*this)),
+ loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
+ module, *this->llvmModule)),
typeTranslator(this->llvmModule->getContext()),
iface(module->getContext()) {
assert(satisfiesLLVMModule(mlirModule) &&
@@ -449,7 +449,6 @@ void ModuleTranslation::forgetMapping(Region &region) {
branchMapping.erase(&op);
if (isa<LLVM::GlobalOp>(op))
globalsMapping.erase(&op);
- accessGroupMetadataMapping.erase(&op);
llvm::append_range(
toProcess,
llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
@@ -994,47 +993,16 @@ LogicalResult ModuleTranslation::convertFunctions() {
return success();
}
-llvm::MDNode *
-ModuleTranslation::getAccessGroup(Operation *op,
- SymbolRefAttr accessGroupRef) const {
- auto metadataName = accessGroupRef.getRootReference();
- auto accessGroupName = accessGroupRef.getLeafReference();
- auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
- op->getParentOp(), metadataName);
- auto *accessGroupOp =
- SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
- return accessGroupMetadataMapping.lookup(accessGroupOp);
-}
-
LogicalResult ModuleTranslation::createAccessGroupMetadata() {
- mlirModule->walk([&](LLVM::MetadataOp metadatas) {
- metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
- llvm::LLVMContext &ctx = llvmModule->getContext();
- llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {});
- accessGroupMetadataMapping.insert({op, accessGroup});
- });
- });
- return success();
+ return loopAnnotationTranslation->createAccessGroupMetadata();
}
void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
llvm::Instruction *inst) {
auto populateGroupsMetadata = [&](ArrayAttr groupRefs) {
- if (!groupRefs || groupRefs.empty())
- return;
-
- llvm::Module *module = inst->getModule();
- SmallVector<llvm::Metadata *> groupMDs;
- for (SymbolRefAttr groupRef : groupRefs.getAsRange<SymbolRefAttr>())
- groupMDs.push_back(getAccessGroup(op, groupRef));
-
- llvm::MDNode *node = nullptr;
- if (groupMDs.size() == 1)
- node = llvm::cast<llvm::MDNode>(groupMDs.front());
- else if (groupMDs.size() >= 2)
- node = llvm::MDNode::get(module->getContext(), groupMDs);
-
- inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
+ if (llvm::MDNode *node =
+ loopAnnotationTranslation->getAccessGroups(op, groupRefs))
+ inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
};
auto groupRefs =
@@ -1250,7 +1218,8 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
[](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
if (!attr)
return;
- llvm::MDNode *loopMD = loopAnnotationTranslation->translate(attr, op);
+ llvm::MDNode *loopMD =
+ loopAnnotationTranslation->translateLoopAnnotation(attr, op);
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
}
More information about the Mlir-commits
mailing list