[Mlir-commits] [mlir] 87a0479 - [mlir][llvm] Fuse access_group & loop export (NFC)

Christian Ulmann llvmlistbot at llvm.org
Fri Feb 17 06:33:39 PST 2023


Author: Christian Ulmann
Date: 2023-02-17T15:31:21+01:00
New Revision: 87a0479538fe4fad1cbbf729ee6e1ee35326f093

URL: https://github.com/llvm/llvm-project/commit/87a0479538fe4fad1cbbf729ee6e1ee35326f093
DIFF: https://github.com/llvm/llvm-project/commit/87a0479538fe4fad1cbbf729ee6e1ee35326f093.diff

LOG: [mlir][llvm] Fuse access_group & loop export (NFC)

This commit moves the access group translation into the
LoopAnnotationTranslation class, as these two metadata kinds are only ever
used together.

Drops the access group cleanup from `ModuleTranslation::forgetMapping`,
as that function is only invoked on function regions. Access groups only appear
in the region of a global metadata operation, so the removed erase never matched
any access group and had no effect.

Analogous to https://reviews.llvm.org/D143577

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D144253

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
    mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 2b08d96b680b9..faca8fc5e4fb4 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -120,11 +120,6 @@ class ModuleTranslation {
   /// in these blocks.
   void forgetMapping(Region &region);
 
-  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
-  /// LLVM dialect access group operation.
-  llvm::MDNode *getAccessGroup(Operation *op,
-                               SymbolRefAttr accessGroupRef) const;
-
   /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
   /// LLVM dialect alias scope operation
   llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const;
@@ -332,11 +327,6 @@ class ModuleTranslation {
   /// values after all operations are converted.
   DenseMap<Operation *, llvm::Instruction *> branchMapping;
 
-  /// Mapping from an access group metadata operation to its LLVM metadata.
-  /// This map is populated on module entry and is used to annotate loops (as
-  /// identified via their branches) and contained memory accesses.
-  DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
-
   /// Mapping from an alias scope metadata operation to its LLVM metadata.
   /// This map is populated on module entry.
   DenseMap<Operation *, llvm::MDNode *> aliasScopeMetadataMapping;

diff  --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
index 5f27f97f2fd45..4a535dcc3a8d3 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
@@ -15,12 +15,11 @@ using namespace mlir::LLVM::detail;
 namespace {
 /// Helper class that keeps the state of one attribute to metadata conversion.
 struct LoopAnnotationConversion {
-  LoopAnnotationConversion(LoopAnnotationAttr attr,
-                           ModuleTranslation &moduleTranslation, Operation *op,
-                           LoopAnnotationTranslation &loopAnnotationTranslation)
-      : attr(attr), moduleTranslation(moduleTranslation), op(op),
-        loopAnnotationTranslation(loopAnnotationTranslation),
-        ctx(moduleTranslation.getLLVMContext()) {}
+  LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
+                           LoopAnnotationTranslation &loopAnnotationTranslation,
+                           llvm::LLVMContext &ctx)
+      : attr(attr), op(op),
+        loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
 
   /// Converts this struct's loop annotation into a corresponding LLVMIR
   /// metadata representation.
@@ -46,7 +45,6 @@ struct LoopAnnotationConversion {
   void convertLoopOptions(LoopUnswitchAttr options);
 
   LoopAnnotationAttr attr;
-  ModuleTranslation &moduleTranslation;
   Operation *op;
   LoopAnnotationTranslation &loopAnnotationTranslation;
   llvm::LLVMContext &ctx;
@@ -95,7 +93,8 @@ void LoopAnnotationConversion::convertFollowupNode(StringRef name,
   if (!attr)
     return;
 
-  llvm::MDNode *node = loopAnnotationTranslation.translate(attr, op);
+  llvm::MDNode *node =
+      loopAnnotationTranslation.translateLoopAnnotation(attr, op);
 
   metadataNodes.push_back(
       llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node}));
@@ -225,7 +224,7 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
         llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
     for (SymbolRefAttr accessGroupRef : parallelAccessGroups)
       parallelAccess.push_back(
-          moduleTranslation.getAccessGroup(op, accessGroupRef));
+          loopAnnotationTranslation.getAccessGroup(op, accessGroupRef));
     metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
   }
 
@@ -236,7 +235,8 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
   return loopMD;
 }
 
-llvm::MDNode *LoopAnnotationTranslation::translate(LoopAnnotationAttr attr,
+llvm::MDNode *
+LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr,
                                                    Operation *op) {
   if (!attr)
     return nullptr;
@@ -246,9 +246,47 @@ llvm::MDNode *LoopAnnotationTranslation::translate(LoopAnnotationAttr attr,
     return loopMD;
 
   loopMD =
-      LoopAnnotationConversion(attr, moduleTranslation, op, *this).convert();
+      LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext())
+          .convert();
   // Store a map from this Attribute to the LLVM metadata in case we
   // encounter it again.
   mapLoopMetadata(attr, loopMD);
   return loopMD;
 }
+
+LogicalResult LoopAnnotationTranslation::createAccessGroupMetadata() {
+  mlirModule->walk([&](LLVM::MetadataOp metadatas) {
+    metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
+      llvm::MDNode *accessGroup =
+          llvm::MDNode::getDistinct(llvmModule.getContext(), {});
+      accessGroupMetadataMapping.insert({op, accessGroup});
+    });
+  });
+  return success();
+}
+
+llvm::MDNode *
+LoopAnnotationTranslation::getAccessGroup(Operation *op,
+                                          SymbolRefAttr accessGroupRef) const {
+  auto metadataName = accessGroupRef.getRootReference();
+  auto accessGroupName = accessGroupRef.getLeafReference();
+  auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+      op->getParentOp(), metadataName);
+  auto *accessGroupOp =
+      SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
+  return accessGroupMetadataMapping.lookup(accessGroupOp);
+}
+
+llvm::MDNode *
+LoopAnnotationTranslation::getAccessGroups(Operation *op,
+                                           ArrayAttr accessGroupRefs) const {
+  if (!accessGroupRefs || accessGroupRefs.empty())
+    return nullptr;
+
+  SmallVector<llvm::Metadata *> groupMDs;
+  for (SymbolRefAttr groupRef : accessGroupRefs.getAsRange<SymbolRefAttr>())
+    groupMDs.push_back(getAccessGroup(op, groupRef));
+  if (groupMDs.size() == 1)
+    return llvm::cast<llvm::MDNode>(groupMDs.front());
+  return llvm::MDNode::get(llvmModule.getContext(), groupMDs);
+}

diff  --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
index 0bbd5442510fe..4de54f3998300 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h
@@ -21,14 +21,28 @@ namespace mlir {
 namespace LLVM {
 namespace detail {
 
-/// A helper class that converts a LoopAnnotationAttr into a corresponding
-/// llvm::MDNode.
+/// A helper class that converts LoopAnnotationAttrs and AccessGroupMetadataOps
+/// into a corresponding llvm::MDNodes.
 class LoopAnnotationTranslation {
 public:
-  LoopAnnotationTranslation(LLVM::ModuleTranslation &moduleTranslation)
-      : moduleTranslation(moduleTranslation) {}
+  LoopAnnotationTranslation(Operation *mlirModule, llvm::Module &llvmModule)
+      : mlirModule(mlirModule), llvmModule(llvmModule) {}
 
-  llvm::MDNode *translate(LoopAnnotationAttr attr, Operation *op);
+  llvm::MDNode *translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op);
+
+  /// Traverses the global access group metadata operation in the `mlirModule`
+  /// and creates corresponding LLVM metadata nodes.
+  LogicalResult createAccessGroupMetadata();
+
+  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+  /// LLVM dialect access group operation.
+  llvm::MDNode *getAccessGroup(Operation *op,
+                               SymbolRefAttr accessGroupRef) const;
+
+  /// Returns the LLVM metadata corresponding to a list of symbol reference to
+  /// an mlir LLVM dialect access group operation. Returns nullptr if
+  /// `accessGroupRefs` is null or empty.
+  llvm::MDNode *getAccessGroups(Operation *op, ArrayAttr accessGroupRefs) const;
 
 private:
   /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute.
@@ -47,7 +61,13 @@ class LoopAnnotationTranslation {
   /// The metadata is attached to Latch block branches with this attribute.
   DenseMap<Attribute, llvm::MDNode *> loopMetadataMapping;
 
-  LLVM::ModuleTranslation &moduleTranslation;
+  /// Mapping from an access group metadata operation to its LLVM metadata.
+  /// This map is populated on module entry and is used to annotate loops (as
+  /// identified via their branches) and contained memory accesses.
+  DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
+
+  Operation *mlirModule;
+  llvm::Module &llvmModule;
 };
 
 } // namespace detail

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 232b3d8b22160..04eddde310cf1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -421,8 +421,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
     : mlirModule(module), llvmModule(std::move(llvmModule)),
       debugTranslation(
           std::make_unique<DebugTranslation>(module, *this->llvmModule)),
-      loopAnnotationTranslation(
-          std::make_unique<LoopAnnotationTranslation>(*this)),
+      loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
+          module, *this->llvmModule)),
       typeTranslator(this->llvmModule->getContext()),
       iface(module->getContext()) {
   assert(satisfiesLLVMModule(mlirModule) &&
@@ -449,7 +449,6 @@ void ModuleTranslation::forgetMapping(Region &region) {
           branchMapping.erase(&op);
         if (isa<LLVM::GlobalOp>(op))
           globalsMapping.erase(&op);
-        accessGroupMetadataMapping.erase(&op);
         llvm::append_range(
             toProcess,
             llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
@@ -994,47 +993,16 @@ LogicalResult ModuleTranslation::convertFunctions() {
   return success();
 }
 
-llvm::MDNode *
-ModuleTranslation::getAccessGroup(Operation *op,
-                                  SymbolRefAttr accessGroupRef) const {
-  auto metadataName = accessGroupRef.getRootReference();
-  auto accessGroupName = accessGroupRef.getLeafReference();
-  auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-      op->getParentOp(), metadataName);
-  auto *accessGroupOp =
-      SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
-  return accessGroupMetadataMapping.lookup(accessGroupOp);
-}
-
 LogicalResult ModuleTranslation::createAccessGroupMetadata() {
-  mlirModule->walk([&](LLVM::MetadataOp metadatas) {
-    metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
-      llvm::LLVMContext &ctx = llvmModule->getContext();
-      llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {});
-      accessGroupMetadataMapping.insert({op, accessGroup});
-    });
-  });
-  return success();
+  return loopAnnotationTranslation->createAccessGroupMetadata();
 }
 
 void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
                                                 llvm::Instruction *inst) {
   auto populateGroupsMetadata = [&](ArrayAttr groupRefs) {
-    if (!groupRefs || groupRefs.empty())
-      return;
-
-    llvm::Module *module = inst->getModule();
-    SmallVector<llvm::Metadata *> groupMDs;
-    for (SymbolRefAttr groupRef : groupRefs.getAsRange<SymbolRefAttr>())
-      groupMDs.push_back(getAccessGroup(op, groupRef));
-
-    llvm::MDNode *node = nullptr;
-    if (groupMDs.size() == 1)
-      node = llvm::cast<llvm::MDNode>(groupMDs.front());
-    else if (groupMDs.size() >= 2)
-      node = llvm::MDNode::get(module->getContext(), groupMDs);
-
-    inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
+    if (llvm::MDNode *node =
+            loopAnnotationTranslation->getAccessGroups(op, groupRefs))
+      inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
   };
 
   auto groupRefs =
@@ -1250,7 +1218,8 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
               [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
   if (!attr)
     return;
-  llvm::MDNode *loopMD = loopAnnotationTranslation->translate(attr, op);
+  llvm::MDNode *loopMD =
+      loopAnnotationTranslation->translateLoopAnnotation(attr, op);
   inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
 }
 


        


More information about the Mlir-commits mailing list