[Mlir-commits] [mlir] 4a2930f - [mlir] Add loop codegen options to some LLVM dialect ops.

Alex Zinenko llvmlistbot at llvm.org
Thu Mar 4 00:02:07 PST 2021


Author: Arpith C. Jacob
Date: 2021-03-04T09:01:57+01:00
New Revision: 4a2930f4950dbeacaf4da6fe9445215934296cce

URL: https://github.com/llvm/llvm-project/commit/4a2930f4950dbeacaf4da6fe9445215934296cce
DIFF: https://github.com/llvm/llvm-project/commit/4a2930f4950dbeacaf4da6fe9445215934296cce.diff

LOG: [mlir] Add loop codegen options to some LLVM dialect ops.

Add a loop option attribute (`#llvm.loopopt`) together with `llvm.metadata`
and `llvm.access_group` ops, and generate `llvm.loop` metadata attached to
branch instructions to control loop code generation.

Reviewed By: ftynse, mehdi_amini

Differential Revision: https://reviews.llvm.org/D96820

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 22ff1517f77b..ac5b5907bf82 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -48,6 +48,7 @@ namespace detail {
 struct LLVMTypeStorage;
 struct LLVMDialectImpl;
 struct BitmaskEnumStorage;
+struct LoopOptionAttrStorage;
 } // namespace detail
 
 /// An attribute that specifies LLVM instruction fastmath flags.
@@ -64,6 +65,38 @@ class FMFAttr : public Attribute::AttrBase<FMFAttr, Attribute,
   static Attribute parse(DialectAsmParser &parser);
 };
 
+/// An attribute that holds a single LLVM loop codegen option and its value.
+class LoopOptionAttr
+    : public Attribute::AttrBase<LoopOptionAttr, Attribute,
+                                 detail::LoopOptionAttrStorage> {
+public:
+  using Base::Base;
+
+  /// Specifies the llvm.loop.unroll.disable metadata.
+  static LoopOptionAttr getDisableUnroll(MLIRContext *context,
+                                         bool disable = true);
+
+  /// Specifies the llvm.licm.disable metadata.
+  static LoopOptionAttr getDisableLICM(MLIRContext *context,
+                                       bool disable = true);
+
+  /// Specifies the llvm.loop.interleave.count metadata.
+  static LoopOptionAttr getInterleaveCount(MLIRContext *context, int32_t count);
+
+  /// Returns the loop option kind, e.g. disable_unroll.
+  LoopOptionCase getCase() const;
+
+  /// Returns the boolean value of the option. Only valid for boolean options.
+  bool getBool() const;
+
+  /// Returns the integer value associated with a loop option. Only valid for
+  /// integer options.
+  int32_t getInt() const;
+
+  void print(DialectAsmPrinter &p) const;
+  static Attribute parse(DialectAsmParser &parser);
+};
+
 } // namespace LLVM
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 541f7ebfadfa..f0b4c69b6ae6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -32,6 +32,9 @@ def LLVM_Dialect : Dialect {
     static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
     static StringRef getAlignAttrName() { return "llvm.align"; }
     static StringRef getNoAliasAttrName() { return "llvm.noalias"; }
+    static StringRef getLoopAttrName() { return "llvm.loop"; }
+    static StringRef getParallelAccessAttrName() { return "parallel_access"; }
+    static StringRef getLoopOptionsAttrName() { return "options"; }
 
     /// Verifies if the given string is a well-formed data layout descriptor.
     /// Uses `reportError` to report errors.

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 076b8ed96e4e..956c5551ab78 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -47,6 +47,18 @@ def LLVM_FMFAttr : DialectAttr<
           "::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())";
 }
 
+def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>;
+def LOptDisableLICM : I32EnumAttrCase<"disable_licm", 2>;
+def LOptInterleaveCount : I32EnumAttrCase<"interleave_count", 3>;
+
+def LoopOptionCase : I32EnumAttr<
+    "LoopOptionCase",
+    "LLVM loop option",
+    [LOptDisableUnroll, LOptDisableLICM, LOptInterleaveCount
+    ]> {  // Each case lowers to one operand of an llvm.loop metadata node.
+  let cppNamespace = "::mlir::LLVM";
+}
+
 class LLVM_Builder<string builder> {
   string llvmBuilder = builder;
 }
@@ -827,6 +839,48 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof"> {
   let verifier = "return ::verify(*this);";
 }
 
+def LLVM_MetadataOp : LLVM_Op<"metadata", [
+   NoRegionArguments, SymbolTable, Symbol
+]> {
+  let arguments = (ins
+    SymbolNameAttr:$sym_name
+  );
+  let summary = "LLVM dialect metadata.";
+  let description = [{
+    The `llvm.metadata` op is a symbol table that defines one or more metadata
+    nodes. Currently only `llvm.access_group` operations are supported in it.
+
+    Example:
+      llvm.metadata @metadata {
+        llvm.access_group @group1
+        llvm.access_group @group2
+        llvm.return
+      }
+  }];
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "$sym_name attr-dict-with-keyword $body";
+}
+
+def LLVM_AccessGroupMetadataOp : LLVM_Op<"access_group", [
+  HasParent<"MetadataOp">, Symbol
+]> {
+  let arguments = (ins
+    SymbolNameAttr:$sym_name
+  );
+  let summary = "LLVM dialect access group metadata.";
+  let description = [{
+    Defines an access group metadata node that can be attached to any
+    instruction that potentially accesses memory. The access group may be
+    attached to a memory accessing instruction via the `llvm.access.group`
+    metadata, and to a branch instruction in the loop latch block via the
+    `llvm.loop.parallel_accesses` metadata.
+
+    See the following link for more details:
+    https://llvm.org/docs/LangRef.html#llvm-access-group-metadata
+  }];
+  let assemblyFormat = "$sym_name attr-dict";
+}
+
 def LLVM_GlobalOp : LLVM_Op<"mlir.global",
     [IsolatedFromAbove, SingleBlockImplicitTerminator<"ReturnOp">, Symbol]> {
   let arguments = (ins

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 698641deb99d..748268575f86 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -110,6 +110,24 @@ class ModuleTranslation {
     return branchMapping.lookup(op);
   }
 
+  /// Returns the LLVM metadata node created for the access group op named by
+  /// `accessGroupRef` (@metadata::@group), looked up relative to `opInst`.
+  llvm::MDNode *getAccessGroup(Operation &opInst,
+                               SymbolRefAttr accessGroupRef) const;
+
+  /// Returns the loop metadata node cached for an `llvm.loop` attribute, or
+  /// null if none has been mapped yet.
+  llvm::MDNode *lookupLoopOptionsMetadata(Attribute options) const {
+    return loopOptionsMetadataMapping.lookup(options);
+  }
+
+  void mapLoopOptionsMetadata(Attribute options, llvm::MDNode *metadata) {
+    auto result = loopOptionsMetadataMapping.try_emplace(options, metadata);
+    (void)result; // Only read by the assert (unused under NDEBUG).
+    assert(result.second &&
+           "attempting to map loop options that was already mapped");
+  }
+
   /// Converts the type from MLIR LLVM dialect to LLVM.
   llvm::Type *convertType(Type type);
 
@@ -167,6 +185,10 @@ class ModuleTranslation {
   LogicalResult convertGlobals();
   LogicalResult convertOneFunction(LLVMFuncOp func);
 
+  /// Creates a distinct, empty LLVM metadata node for every access_group
+  /// operation nested in an llvm.metadata operation.
+  LogicalResult createAccessGroupMetadata();
+
   /// Translates dialect attributes attached to the given operation.
   LogicalResult convertDialectAttributes(Operation *op);
 
@@ -198,6 +220,16 @@ class ModuleTranslation {
   /// they are converted to. This allows for connecting PHI nodes to the source
   /// values after all operations are converted.
   DenseMap<Operation *, llvm::Instruction *> branchMapping;
+
+  /// Mapping from an access group metadata operation to its LLVM metadata.
+  /// This map is populated before any function is converted and is used to
+  /// fill the `llvm.loop.parallel_accesses` lists of loop branches.
+  DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
+
+  /// Mapping from an `llvm.loop` attribute to the loop metadata node built
+  /// for it. The node is cached so that all branches carrying an equal
+  /// attribute (e.g. a loop's header and latch branches) share one loop ID.
+  DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping;
 };
 
 namespace detail {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d63e7753f93e..0538862b56e1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -56,6 +56,26 @@ struct BitmaskEnumStorage : public AttributeStorage {
 
   KeyTy value = 0;
 };
+
+struct LoopOptionAttrStorage : public AttributeStorage {
+  using KeyTy = std::pair<uint64_t, int32_t>; // (LoopOptionCase, value)
+
+  explicit LoopOptionAttrStorage(uint64_t option, int32_t value)
+      : option(option), value(value) {}
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(option, value);
+  }
+
+  static LoopOptionAttrStorage *
+  construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<LoopOptionAttrStorage>())
+        LoopOptionAttrStorage(key.first, key.second);
+  }
+
+  uint64_t option; // Underlying value of a LoopOptionCase enumerator.
+  int32_t value;   // 0/1 for boolean options, the count otherwise.
+};
 } // namespace detail
 } // namespace LLVM
 } // namespace mlir
@@ -2158,7 +2178,7 @@ static LogicalResult verify(FenceOp &op) {
 //===----------------------------------------------------------------------===//
 
 void LLVMDialect::initialize() {
-  addAttributes<FMFAttr>();
+  addAttributes<FMFAttr, LoopOptionAttr>();
 
   // clang-format off
   addTypes<LLVMVoidType,
@@ -2213,6 +2233,57 @@ LogicalResult LLVMDialect::verifyDataLayoutString(
 /// Verify LLVM dialect attributes.
 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attr) {
+  // Validate the shape of the `llvm.loop` dictionary (parallel_access symbol
+  // references and the loop option list); module translation relies on it.
+  if (attr.first.strref() == LLVMDialect::getLoopAttrName()) {
+    auto loopAttr = attr.second.dyn_cast<DictionaryAttr>();
+    if (!loopAttr)
+      return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
+                               << "' to be a dictionary attribute";
+    Optional<NamedAttribute> parallelAccessGroup =
+        loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
+    if (parallelAccessGroup.hasValue()) {
+      auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>();
+      if (!accessGroups)
+        return op->emitOpError()
+               << "expected '" << LLVMDialect::getParallelAccessAttrName()
+               << "' to be an array attribute";
+      for (Attribute groupAttr : accessGroups) {
+        auto accessGroupRef = groupAttr.dyn_cast<SymbolRefAttr>();
+        if (!accessGroupRef)
+          return op->emitOpError()
+                 << "expected '" << groupAttr << "' to be a symbol reference";
+        StringRef metadataName = accessGroupRef.getRootReference();
+        auto metadataOp =
+            SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+                op->getParentOp(), metadataName);
+        if (!metadataOp)
+          return op->emitOpError()
+                 << "expected '" << groupAttr << "' to reference a metadata op";
+        StringRef accessGroupName = accessGroupRef.getLeafReference();
+        auto accessGroupOp = SymbolTable::lookupNearestSymbolFrom<
+            LLVM::AccessGroupMetadataOp>(metadataOp, accessGroupName);
+        if (!accessGroupOp)
+          return op->emitOpError() << "expected '" << groupAttr
+                                   << "' to reference an access_group op";
+      }
+    }
+
+    Optional<NamedAttribute> loopOptions =
+        loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
+    if (loopOptions.hasValue()) {
+      auto options = loopOptions->second.dyn_cast<ArrayAttr>();
+      if (!options)
+        return op->emitOpError()
+               << "expected '" << LLVMDialect::getLoopOptionsAttrName()
+               << "' to be an array attribute";
+      if (!llvm::all_of(options, [](Attribute option) {
+            return option.isa<LoopOptionAttr>();
+          }))
+        return op->emitOpError() << "invalid loop options list " << options;
+    }
+  }
+
   // If the data layout attribute is present, it must use the LLVM data layout
   // syntax. Try parsing it and report errors in case of failure. Users of this
   // attribute may assume it is well-formed and can pass it to the (asserting)
@@ -2343,6 +2414,109 @@ Attribute FMFAttr::parse(DialectAsmParser &parser) {
   return FMFAttr::get(flags, parser.getBuilder().getContext());
 }
 
+LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context,
+                                                bool disable) {
+  auto option = LoopOptionCase::disable_unroll;
+  return Base::get(context, static_cast<uint64_t>(option),
+                   static_cast<int32_t>(disable));
+}
+
+LoopOptionAttr LoopOptionAttr::getDisableLICM(MLIRContext *context,
+                                              bool disable) {
+  auto option = LoopOptionCase::disable_licm;
+  return Base::get(context, static_cast<uint64_t>(option),
+                   static_cast<int32_t>(disable));
+}
+
+LoopOptionAttr LoopOptionAttr::getInterleaveCount(MLIRContext *context,
+                                                  int32_t count) {
+  auto option = LoopOptionCase::interleave_count;
+  return Base::get(context, static_cast<uint64_t>(option),
+                   static_cast<int32_t>(count));
+}
+
+LoopOptionCase LoopOptionAttr::getCase() const {
+  return static_cast<LoopOptionCase>(getImpl()->option);
+}
+
+bool LoopOptionAttr::getBool() const {
+  LoopOptionCase option = getCase();
+  (void)option; // Only read by the assert (unused under NDEBUG).
+  assert((option == LoopOptionCase::disable_licm ||
+          option == LoopOptionCase::disable_unroll) &&
+         "expected a boolean loop option");
+  return static_cast<bool>(getImpl()->value);
+}
+
+int32_t LoopOptionAttr::getInt() const {
+  LoopOptionCase option = getCase();
+  (void)option; // Only read by the assert (unused under NDEBUG).
+  assert(option == LoopOptionCase::interleave_count &&
+         "expected an integer loop option");
+  return getImpl()->value;
+}
+
+void LoopOptionAttr::print(DialectAsmPrinter &printer) const {
+  printer << "loopopt<" << stringifyEnum(getCase()) << " = ";
+  switch (getCase()) {
+  case LoopOptionCase::disable_licm:
+  case LoopOptionCase::disable_unroll:
+    printer << (getBool() ? "true" : "false");
+    break;
+  case LoopOptionCase::interleave_count:
+    printer << getInt();
+    break;
+  }
+  printer << ">";
+}
+
+Attribute LoopOptionAttr::parse(DialectAsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  StringRef optionName;
+  if (failed(parser.parseKeyword(&optionName)))
+    return {};
+
+  auto option = symbolizeLoopOptionCase(optionName);
+  if (!option) {
+    parser.emitError(parser.getNameLoc(), "unknown loop option: ")
+        << optionName;
+    return {};
+  }
+
+  if (failed(parser.parseEqual()))
+    return {};
+
+  int32_t value = 0;
+  switch (*option) {
+  case LoopOptionCase::disable_licm:
+  case LoopOptionCase::disable_unroll:
+    if (succeeded(parser.parseOptionalKeyword("true")))
+      value = 1;
+    else if (succeeded(parser.parseOptionalKeyword("false")))
+      value = 0;
+    else {
+      parser.emitError(parser.getNameLoc(),
+                       "expected boolean value 'true' or 'false'");
+      return {};
+    }
+    break;
+  case LoopOptionCase::interleave_count:
+    // parseInteger reports its own diagnostic (missing or out-of-range
+    // integer); emitting a second one here would only duplicate it.
+    if (failed(parser.parseInteger(value)))
+      return {};
+    break;
+  }
+
+  if (failed(parser.parseGreater()))
+    return {};
+
+  return Base::get(parser.getBuilder().getContext(),
+                   static_cast<uint64_t>(*option), value);
+}
+
 Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
                                       Type type) const {
   if (type) {
@@ -2356,14 +2530,18 @@ Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
   if (attrKind == "fastmath")
     return FMFAttr::parse(parser);
 
-  parser.emitError(parser.getNameLoc(), "Unknown attrribute type: ")
-      << attrKind;
+  if (attrKind == "loopopt")
+    return LoopOptionAttr::parse(parser);
+
+  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
   return {};
 }
 
 void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
   if (auto fmf = attr.dyn_cast<FMFAttr>())
     fmf.print(os);
+  else if (auto lopt = attr.dyn_cast<LoopOptionAttr>())
+    lopt.print(os);
   else
     llvm_unreachable("Unknown attribute type");
 }

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index baf7107db80f..1f03092628a5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -168,6 +168,82 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
   return ret;
 }
 
+/// Returns the LLVM metadata node `!{!"<option name>", <value>}` for a single
+/// loop option; it is added as one operand of an llvm.loop node.
+static llvm::MDNode *getLoopOptionMetadata(llvm::LLVMContext &ctx,
+                                           LoopOptionAttr option) {
+  StringRef name;
+  llvm::Constant *value = nullptr;
+  switch (option.getCase()) {
+  case LoopOptionCase::disable_licm:
+    name = "llvm.licm.disable";
+    value = llvm::ConstantInt::getBool(ctx, option.getBool());
+    break;
+  case LoopOptionCase::disable_unroll:
+    name = "llvm.loop.unroll.disable";
+    value = llvm::ConstantInt::getBool(ctx, option.getBool());
+    break;
+  case LoopOptionCase::interleave_count:
+    name = "llvm.loop.interleave.count";
+    value = llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx),
+                                   option.getInt());
+    break;
+  }
+  return llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
+                                 llvm::ConstantAsMetadata::get(value)});
+}
+
+/// Attaches `!llvm.loop` metadata, built from `opInst`'s `llvm.loop` attribute
+/// (if present), to `llvmInst`; the node is cached per attribute.
+static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst,
+                            llvm::IRBuilderBase &builder,
+                            LLVM::ModuleTranslation &moduleTranslation) {
+  if (Attribute attr = opInst.getAttr(LLVMDialect::getLoopAttrName())) {
+    llvm::MDNode *loopMD = moduleTranslation.lookupLoopOptionsMetadata(attr);
+    if (!loopMD) {
+      llvm::LLVMContext &ctx = builder.getContext();
+      SmallVector<llvm::Metadata *> loopOptions;
+      // Reserve operand 0 for loop id self reference.
+      auto dummy = llvm::MDNode::getTemporary(ctx, llvm::None);
+      loopOptions.push_back(dummy.get());
+
+      auto loopAttr = attr.cast<DictionaryAttr>();
+      auto parallelAccessGroup =
+          loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
+      if (parallelAccessGroup.hasValue()) {
+        SmallVector<llvm::Metadata *> parallelAccess;
+        parallelAccess.push_back(
+            llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
+        for (SymbolRefAttr accessGroupRef :
+             parallelAccessGroup->second.cast<ArrayAttr>()
+                 .getAsRange<SymbolRefAttr>())
+          parallelAccess.push_back(
+              moduleTranslation.getAccessGroup(opInst, accessGroupRef));
+        loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess));
+      }
+
+      auto loopOptionsAttr =
+          loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
+      if (loopOptionsAttr.hasValue()) {
+        for (LoopOptionAttr loopOption :
+             loopOptionsAttr->second.cast<ArrayAttr>()
+                 .getAsRange<LoopOptionAttr>())
+          loopOptions.push_back(getLoopOptionMetadata(ctx, loopOption));
+      }
+
+      // A loop ID is a distinct node whose operand 0 refers to itself.
+      loopMD = llvm::MDNode::getDistinct(ctx, loopOptions);
+      loopMD->replaceOperandWith(0, loopMD);
+
+      // Cache the node so that all branches carrying an equal attribute
+      // (e.g. a loop's header and latch branches) share one loop ID.
+      moduleTranslation.mapLoopOptionsMetadata(attr, loopMD);
+    }
+
+    llvmInst.setMetadata(llvm::LLVMContext::MD_loop, loopMD);
+  }
+}
+
 static LogicalResult
 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
@@ -295,6 +371,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     llvm::BranchInst *branch =
         builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
     moduleTranslation.mapBranch(&opInst, branch);
+    setLoopMetadata(opInst, *branch, builder, moduleTranslation);
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
@@ -316,6 +393,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
         moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
         moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
     moduleTranslation.mapBranch(&opInst, branch);
+    setLoopMetadata(opInst, *branch, builder, moduleTranslation);
     return success();
   }
   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d52cc78a48fc..3a03b278e264 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -593,7 +593,7 @@ LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
 /// Check whether the module contains only supported ops directly in its body.
 static LogicalResult checkSupportedModuleOps(Operation *m) {
   for (Operation &o : getModuleBody(m).getOperations())
-    if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp>(&o) &&
+    if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::MetadataOp>(&o) &&
         !o.hasTrait<OpTrait::IsTerminator>())
       return o.emitOpError("unsupported module-level operation");
   return success();
@@ -633,6 +633,29 @@ LogicalResult ModuleTranslation::convertFunctions() {
   return success();
 }
 
+llvm::MDNode *
+ModuleTranslation::getAccessGroup(Operation &opInst,
+                                  SymbolRefAttr accessGroupRef) const {
+  auto metadataName = accessGroupRef.getRootReference();
+  auto accessGroupName = accessGroupRef.getLeafReference();
+  auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+      opInst.getParentOp(), metadataName);
+  auto *accessGroupOp =
+      SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
+  return accessGroupMetadataMapping.lookup(accessGroupOp);
+}
+
+LogicalResult ModuleTranslation::createAccessGroupMetadata() {
+  // Access group ops may only appear inside llvm.metadata ops (HasParent),
+  // so a single walk finds all of them; each gets a distinct empty node.
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  mlirModule->walk([&](LLVM::AccessGroupMetadataOp op) {
+    llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {});
+    accessGroupMetadataMapping.insert({op, accessGroup});
+  });
+  return success();
+}
+
 llvm::Type *ModuleTranslation::convertType(Type type) {
   return typeTranslator.translateType(type);
 }
@@ -697,6 +720,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
     return nullptr;
   if (failed(translator.convertGlobals()))
     return nullptr;
+  if (failed(translator.createAccessGroupMetadata()))
+    return nullptr;
   if (failed(translator.convertFunctions()))
     return nullptr;
   if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index b7984805b374..6a45b1f67e71 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -713,3 +713,86 @@ llvm.mlir.global common @non_zero_compound_global_common_linkage(dense<[0, 0, 0,
 
 // expected-error @below {{expected array type for 'appending' linkage}}
 llvm.mlir.global appending @non_array_type_global_appending_linkage() : i32
+
+// -----
+
+module {
+  llvm.func @loopOptions() {
+      // expected-error @below {{expected 'llvm.loop' to be a dictionary attribute}}
+      llvm.br ^bb4 {llvm.loop = "test"}
+    ^bb4:
+      llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @loopOptions() {
+      // expected-error @below {{expected 'parallel_access' to be an array attribute}}
+      llvm.br ^bb4 {llvm.loop = {parallel_access = "loop"}}
+    ^bb4:
+      llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @loopOptions() {
+      // expected-error @below {{expected '"loop"' to be a symbol reference}}
+      llvm.br ^bb4 {llvm.loop = {parallel_access = ["loop"]}}
+    ^bb4:
+      llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @loopOptions() {
+      // expected-error @below {{expected '@func1' to reference a metadata op}}
+      llvm.br ^bb4 {llvm.loop = {parallel_access = [@func1]}}
+    ^bb4:
+      llvm.return
+  }
+  llvm.func @func1() {
+    llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @loopOptions() {
+      // expected-error @below {{expected '@metadata' to reference an access_group op}}
+      llvm.br ^bb4 {llvm.loop = {parallel_access = [@metadata]}}
+    ^bb4:
+      llvm.return
+  }
+  llvm.metadata @metadata {
+    llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @loopOptions() {
+      // expected-error @below {{expected 'options' to be an array attribute}}
+      llvm.br ^bb4 {llvm.loop = {options = "name"}}
+    ^bb4:
+      llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @loopOptions() {
+      // expected-error @below {{invalid loop options list}}
+      llvm.br ^bb4 {llvm.loop = {options = ["name"]}}
+    ^bb4:
+      llvm.return
+  }
+}

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 27c5783a4d8d..b202432ccbae 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -419,3 +419,12 @@ func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) {
   %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32
   return
 }
+
+module {
+  // CHECK: llvm.metadata @metadata attributes {test_attribute} {
+  llvm.metadata @metadata attributes {test_attribute} {
+    // CHECK: llvm.access_group @group1
+    llvm.access_group @group1
+    llvm.return
+  }
+}

diff  --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index abe2deb685bb..1109345231f2 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1469,3 +1469,38 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
 }
 
 // CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19}
+
+// -----
+
+module {
+  llvm.func @loopOptions(%arg1 : i32, %arg2 : i32) {
+      %0 = llvm.mlir.constant(0 : i32) : i32
+      %4 = llvm.alloca %arg1 x i32 : (i32) -> (!llvm.ptr<i32>)
+      llvm.br ^bb3(%0 : i32)
+    ^bb3(%1: i32):
+      %2 = llvm.icmp "slt" %1, %arg1 : i32
+      // CHECK: br i1 {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]]
+      llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
+    ^bb4:
+      %3 = llvm.add %1, %arg2  : i32
+      %5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr<i32>
+      // CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]]
+      llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
+    ^bb5:
+      llvm.return
+  }
+
+  llvm.metadata @metadata {
+    llvm.access_group @group1
+    llvm.access_group @group2
+    llvm.return
+  }
+}
+
+// CHECK: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], ![[PA_NODE:[0-9]+]], ![[UNROLL_DISABLE_NODE:[0-9]+]], ![[LICM_DISABLE_NODE:[0-9]+]], ![[INTERLEAVE_NODE:[0-9]+]]}
+// CHECK: ![[PA_NODE]] = !{!"llvm.loop.parallel_accesses", ![[GROUP_NODE1:[0-9]+]], ![[GROUP_NODE2:[0-9]+]]}
+// CHECK: ![[GROUP_NODE1]] = distinct !{}
+// CHECK: ![[GROUP_NODE2]] = distinct !{}
+// CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true}
+// CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true}
+// CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1}


        


More information about the Mlir-commits mailing list