[Mlir-commits] [mlir] e630a50 - [mlir][llvm] Fuse MD_access_group & MD_loop import

Christian Ulmann llvmlistbot at llvm.org
Thu Feb 9 05:43:43 PST 2023


Author: Christian Ulmann
Date: 2023-02-09T14:43:02+01:00
New Revision: e630a502230f8779bddd214094d28fef61fde866

URL: https://github.com/llvm/llvm-project/commit/e630a502230f8779bddd214094d28fef61fde866
DIFF: https://github.com/llvm/llvm-project/commit/e630a502230f8779bddd214094d28fef61fde866.diff

LOG: [mlir][llvm] Fuse MD_access_group & MD_loop import

This commit moves the importing logic of access group metadata into the
loop annotation importer. These two metadata imports can be grouped
because access groups are only used in combination with
`llvm.loop.parallel_accesses`.

As a nice side effect, this commit decouples the LoopAnnotationImporter
from the ModuleImport class.

Differential Revision: https://reviews.llvm.org/D143577

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
    mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/Target/LLVMIR/Import/import-failure.ll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 23b1fbc29dd72..3265c32372410 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -302,9 +302,6 @@ class ModuleImport {
   /// to the LLVMIR dialect TBAA operations corresponding to these
   /// nodes.
   DenseMap<const llvm::MDNode *, SymbolRefAttr> tbaaMapping;
-  /// Mapping between original LLVM access group metadata nodes and the symbol
-  /// references pointing to the imported MLIR access group operations.
-  DenseMap<const llvm::MDNode *, SymbolRefAttr> accessGroupMapping;
   /// The stateful type translator (contains named structs).
   LLVM::TypeFromLLVMIRTranslator typeTranslator;
   /// Stateful debug information importer.

diff  --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
index a3218e13307e4..a3cbf2bcd47c1 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
@@ -16,11 +16,9 @@ using namespace mlir::LLVM::detail;
 namespace {
 /// Helper class that keeps the state of one metadata to attribute conversion.
 struct LoopMetadataConversion {
-  LoopMetadataConversion(const llvm::MDNode *node, ModuleImport &moduleImport,
-                         Location loc,
+  LoopMetadataConversion(const llvm::MDNode *node, Location loc,
                          LoopAnnotationImporter &loopAnnotationImporter)
-      : node(node), moduleImport(moduleImport), loc(loc),
-        loopAnnotationImporter(loopAnnotationImporter),
+      : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
         ctx(loc->getContext()){};
   /// Converts this structs loop metadata node into a LoopAnnotationAttr.
   LoopAnnotationAttr convert();
@@ -55,7 +53,6 @@ struct LoopMetadataConversion {
 
   llvm::StringMap<const llvm::MDNode *> propertyMap;
   const llvm::MDNode *node;
-  ModuleImport &moduleImport;
   Location loc;
   LoopAnnotationImporter &loopAnnotationImporter;
   MLIRContext *ctx;
@@ -233,7 +230,7 @@ LoopMetadataConversion::lookupFollowupNode(StringRef name) {
   if (*node == nullptr)
     return LoopAnnotationAttr(nullptr);
 
-  return loopAnnotationImporter.translate(*node, loc);
+  return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
 }
 
 static bool isEmptyOrNull(const Attribute attr) { return !attr; }
@@ -360,7 +357,7 @@ LoopMetadataConversion::convertParallelAccesses() {
   SmallVector<SymbolRefAttr> refs;
   for (llvm::MDNode *node : *nodes) {
     FailureOr<SmallVector<SymbolRefAttr>> accessGroups =
-        moduleImport.lookupAccessGroupAttrs(node);
+        loopAnnotationImporter.lookupAccessGroupAttrs(node);
     if (failed(accessGroups))
       return emitWarning(loc) << "could not lookup access group";
     llvm::append_range(refs, *accessGroups);
@@ -398,8 +395,9 @@ LoopAnnotationAttr LoopMetadataConversion::convert() {
       parallelAccesses);
 }
 
-LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node,
-                                                     Location loc) {
+LoopAnnotationAttr
+LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node,
+                                                Location loc) {
   if (!node)
     return {};
 
@@ -409,9 +407,60 @@ LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node,
   if (it != loopMetadataMapping.end())
     return it->getSecond();
 
-  LoopAnnotationAttr attr =
-      LoopMetadataConversion(node, moduleImport, loc, *this).convert();
+  LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert();
 
   mapLoopMetadata(node, attr);
   return attr;
 }
+
+LogicalResult LoopAnnotationImporter::translateAccessGroup(
+    const llvm::MDNode *node, Location loc, MetadataOp metadataOp) {
+  SmallVector<const llvm::MDNode *> accessGroups;
+  if (!node->getNumOperands())
+    accessGroups.push_back(node);
+  for (const llvm::MDOperand &operand : node->operands()) {
+    auto *childNode = dyn_cast<llvm::MDNode>(operand);
+    if (!childNode)
+      return emitWarning(loc)
+             << "expected access group operands to be metadata nodes";
+    accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
+  }
+
+  // Convert all entries of the access group list to access group operations.
+  for (const llvm::MDNode *accessGroup : accessGroups) {
+    if (accessGroupMapping.count(accessGroup))
+      continue;
+    // Verify the access group node is distinct and empty.
+    if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
+      return emitWarning(loc)
+             << "expected an access group node to be empty and distinct";
+
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToEnd(&metadataOp.getBody().back());
+    auto groupOp = builder.create<AccessGroupMetadataOp>(
+        loc, llvm::formatv("group_{0}", accessGroupMapping.size()).str());
+    // Add a mapping from the access group node to the symbol reference pointing
+    // to the newly created operation.
+    accessGroupMapping[accessGroup] = SymbolRefAttr::get(
+        builder.getContext(), metadataOp.getSymName(),
+        FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName()));
+  }
+  return success();
+}
+
+FailureOr<SmallVector<SymbolRefAttr>>
+LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
+  // An access group node is either a single access group or an access group
+  // list.
+  SmallVector<SymbolRefAttr> accessGroups;
+  if (!node->getNumOperands())
+    accessGroups.push_back(accessGroupMapping.lookup(node));
+  for (const llvm::MDOperand &operand : node->operands()) {
+    auto *node = cast<llvm::MDNode>(operand.get());
+    accessGroups.push_back(accessGroupMapping.lookup(node));
+  }
+  // Exit if one of the access group node lookups failed.
+  if (llvm::is_contained(accessGroups, nullptr))
+    return failure();
+  return accessGroups;
+}

diff  --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h
index bd6f5ef350e64..5d69a63a21502 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h
@@ -21,13 +21,28 @@ namespace mlir {
 namespace LLVM {
 namespace detail {
 
-/// A helper class that converts a `llvm.loop` metadata node into a
-/// corresponding LoopAnnotationAttr.
+/// A helper class that converts llvm.loop metadata nodes into corresponding
+/// LoopAnnotationAttrs and llvm.access.group nodes into
+/// AccessGroupMetadataOps.
 class LoopAnnotationImporter {
 public:
-  explicit LoopAnnotationImporter(ModuleImport &moduleImport)
-      : moduleImport(moduleImport) {}
-  LoopAnnotationAttr translate(const llvm::MDNode *node, Location loc);
+  explicit LoopAnnotationImporter(OpBuilder &builder) : builder(builder) {}
+  LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node,
+                                             Location loc);
+
+  /// Converts all LLVM access groups starting from node to MLIR access group
+  /// operations nested in the region of metadataOp. It stores a mapping from
+  /// every nested access group node to the symbol pointing to the translated
+  /// operation. Returns success if all conversions succeed and failure
+  /// otherwise.
+  LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc,
+                                     MetadataOp metadataOp);
+
+  /// Returns the symbol references pointing to the access group operations that
+  /// map to the access group nodes starting from the access group metadata
+  /// node. Returns failure, if any of the symbol references cannot be found.
+  FailureOr<SmallVector<SymbolRefAttr>>
+  lookupAccessGroupAttrs(const llvm::MDNode *node) const;
 
 private:
   /// Returns the LLVM metadata corresponding to a llvm loop metadata attribute.
@@ -42,8 +57,11 @@ class LoopAnnotationImporter {
            "attempting to map loop options that was already mapped");
   }
 
-  ModuleImport &moduleImport;
+  OpBuilder &builder;
   DenseMap<const llvm::MDNode *, LoopAnnotationAttr> loopMetadataMapping;
+  /// Mapping between original LLVM access group metadata nodes and the symbol
+  /// references pointing to the imported MLIR access group operations.
+  DenseMap<const llvm::MDNode *, SymbolRefAttr> accessGroupMapping;
 };
 
 } // namespace detail

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index a5142f96fe0a3..992345686e716 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -255,7 +255,8 @@ ModuleImport::ModuleImport(ModuleOp mlirModule,
       iface(mlirModule->getContext()),
       typeTranslator(*mlirModule->getContext()),
       debugImporter(std::make_unique<DebugImporter>(mlirModule)),
-      loopAnnotationImporter(std::make_unique<LoopAnnotationImporter>(*this)) {
+      loopAnnotationImporter(
+          std::make_unique<LoopAnnotationImporter>(builder)) {
   builder.setInsertionPointToStart(mlirModule.getBody());
 }
 
@@ -512,35 +513,11 @@ LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {
 
 LogicalResult
 ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) {
-  // An access group node is either access group or an access group list. Start
-  // by collecting all access groups to translate.
-  SmallVector<const llvm::MDNode *> accessGroups;
-  if (!node->getNumOperands())
-    accessGroups.push_back(node);
-  for (const llvm::MDOperand &operand : node->operands())
-    accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
-
-  // Convert all entries of the access group list to access group operations.
-  for (const llvm::MDNode *accessGroup : accessGroups) {
-    if (accessGroupMapping.count(accessGroup))
-      continue;
-    // Verify the access group node is distinct and empty.
-    Location loc = mlirModule.getLoc();
-    if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
-      return emitError(loc) << "unsupported access group node: "
-                            << diagMD(accessGroup, llvmModule.get());
-
-    MetadataOp metadataOp = getGlobalMetadataOp();
-    OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToEnd(&metadataOp.getBody().back());
-    auto groupOp = builder.create<AccessGroupMetadataOp>(
-        loc, (Twine("group_") + Twine(accessGroupMapping.size())).str());
-    // Add a mapping from the access group node to the symbol reference pointing
-    // to the newly created operation.
-    accessGroupMapping[accessGroup] = SymbolRefAttr::get(
-        builder.getContext(), metadataOp.getSymName(),
-        FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName()));
-  }
+  Location loc = mlirModule.getLoc();
+  if (failed(loopAnnotationImporter->translateAccessGroup(
+          node, loc, getGlobalMetadataOp())))
+    return emitError(loc) << "unsupported access group node: "
+                          << diagMD(node, llvmModule.get());
   return success();
 }
 
@@ -1587,25 +1564,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
 
 FailureOr<SmallVector<SymbolRefAttr>>
 ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
-  // An access group node is either a single access group or an access group
-  // list.
-  SmallVector<SymbolRefAttr> accessGroups;
-  if (!node->getNumOperands())
-    accessGroups.push_back(accessGroupMapping.lookup(node));
-  for (const llvm::MDOperand &operand : node->operands()) {
-    auto *node = cast<llvm::MDNode>(operand.get());
-    accessGroups.push_back(accessGroupMapping.lookup(node));
-  }
-  // Exit if one of the access group node lookups failed.
-  if (llvm::is_contained(accessGroups, nullptr))
-    return failure();
-  return accessGroups;
+  return loopAnnotationImporter->lookupAccessGroupAttrs(node);
 }
 
 LoopAnnotationAttr
 ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
                                           Location loc) const {
-  return loopAnnotationImporter->translate(node, loc);
+  return loopAnnotationImporter->translateLoopAnnotation(node, loc);
 }
 
 OwningOpRef<ModuleOp>

diff  --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index ea35b92164816..12a605a72d344 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -241,7 +241,8 @@ define dso_local void @tbaa(ptr %0) {
 ; // -----
 
 ; CHECK:      import-failure.ll
-; CHECK-SAME: error: unsupported access group node: !0 = !{}
+; CHECK-SAME: warning: expected an access group node to be empty and distinct
+; CHECK:      error: unsupported access group node: !0 = !{}
 define void @access_group(ptr %arg1) {
   %1 = load i32, ptr %arg1, !llvm.access.group !0
   ret void
@@ -252,7 +253,8 @@ define void @access_group(ptr %arg1) {
 ; // -----
 
 ; CHECK:      import-failure.ll
-; CHECK-SAME: error: unsupported access group node: !1 = distinct !{!"unsupported access group"}
+; CHECK-SAME: warning: expected an access group node to be empty and distinct
+; CHECK:      error: unsupported access group node: !0 = !{!1}
 define void @access_group(ptr %arg1) {
   %1 = load i32, ptr %arg1, !llvm.access.group !0
   ret void
@@ -263,6 +265,18 @@ define void @access_group(ptr %arg1) {
 
 ; // -----
 
+; CHECK:      import-failure.ll
+; CHECK-SAME: warning: expected access group operands to be metadata nodes
+; CHECK:      error: unsupported access group node: !0 = !{i1 false}
+define void @access_group(ptr %arg1) {
+  %1 = load i32, ptr %arg1, !llvm.access.group !0
+  ret void
+}
+
+!0 = !{i1 false}
+
+; // -----
+
 ; CHECK:      import-failure.ll
 ; CHECK-SAME: warning: expected all loop properties to be either debug locations or metadata nodes
 ; CHECK:      import-failure.ll


        


More information about the Mlir-commits mailing list